From a7c9e7fc9d32d41aa0a187fbe133189be8d8f80d Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Fri, 28 Jun 2024 13:29:19 -0400 Subject: [PATCH 01/17] Prototype user defined sql planner might look like --- datafusion/functions-array/src/lib.rs | 1 - datafusion/sql/src/expr/mod.rs | 133 ++++++++++++++++---------- datafusion/sql/src/planner.rs | 33 +++++++ 3 files changed, 113 insertions(+), 54 deletions(-) diff --git a/datafusion/functions-array/src/lib.rs b/datafusion/functions-array/src/lib.rs index b2fcb5717b3a..2e812a3df5fe 100644 --- a/datafusion/functions-array/src/lib.rs +++ b/datafusion/functions-array/src/lib.rs @@ -51,7 +51,6 @@ pub mod set_ops; pub mod sort; pub mod string; pub mod utils; - use datafusion_common::Result; use datafusion_execution::FunctionRegistry; use datafusion_expr::ScalarUDF; diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index a8af37ee6a37..42a742782c9a 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -19,6 +19,7 @@ use arrow_schema::DataType; use arrow_schema::TimeUnit; use datafusion_common::utils::list_ndims; use sqlparser::ast::{CastKind, Expr as SQLExpr, Subscript, TrimWhereField, Value}; +use std::sync::Arc; use datafusion_common::{ internal_datafusion_err, internal_err, not_impl_err, plan_err, DFSchema, Result, @@ -28,10 +29,10 @@ use datafusion_expr::expr::InList; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::{ lit, AggregateFunction, Between, BinaryExpr, Cast, Expr, ExprSchemable, - GetFieldAccess, Like, Literal, Operator, TryCast, + GetFieldAccess, Like, Literal, Operator, ScalarUDF, TryCast, }; -use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; +use crate::planner::{ContextProvider, PlannerContext, SqlToRel, UserDefinedPlanner}; mod binary_op; mod function; @@ -52,7 +53,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { ) -> Result { enum StackEntry { SQLExpr(Box), - Operator(Operator), + Operator(sqlparser::ast::BinaryOperator), } // Virtual stack machine to convert SQLExpr to Expr @@ -69,7 +70,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { SQLExpr::BinaryOp { left, op, right } => { // Note the order that we push the entries to the stack // is important. We want to visit the left node first. - let op = self.parse_sql_binary_op(op)?; stack.push(StackEntry::Operator(op)); stack.push(StackEntry::SQLExpr(right)); stack.push(StackEntry::SQLExpr(left)); @@ -100,63 +100,25 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { fn build_logical_expr( &self, - op: Operator, + op: sqlparser::ast::BinaryOperator, left: Expr, right: Expr, schema: &DFSchema, ) -> Result { - // Rewrite string concat operator to function based on types - // if we get list || list then we rewrite it to array_concat() - // if we get list || non-list then we rewrite it to array_append() - // if we get non-list || list then we rewrite it to array_prepend() - // if we get string || string then we rewrite it to concat() - if op == Operator::StringConcat { - let left_type = left.get_type(schema)?; - let right_type = right.get_type(schema)?; - let left_list_ndims = list_ndims(&left_type); - let right_list_ndims = list_ndims(&right_type); - - // We determine the target function to rewrite based on the list n-dimension, the check is not exact but sufficient. - // The exact validity check is handled in the actual function, so even if there is 3d list appended with 1d list, it is also fine to rewrite. - if left_list_ndims + right_list_ndims == 0 { - // TODO: concat function ignore null, but string concat takes null into consideration - // we can rewrite it to concat if we can configure the behaviour of concat function to the one like `string concat operator` - } else if left_list_ndims == right_list_ndims { - if let Some(udf) = self.context_provider.get_function_meta("array_concat") - { - return Ok(Expr::ScalarFunction(ScalarFunction::new_udf( - udf, - vec![left, right], - ))); - } else { - return internal_err!("array_concat not found"); - } - } else if left_list_ndims > right_list_ndims { - if let Some(udf) = self.context_provider.get_function_meta("array_append") - { - return Ok(Expr::ScalarFunction(ScalarFunction::new_udf( - udf, - vec![left, right], - ))); - } else { - return internal_err!("array_append not found"); - } - } else if left_list_ndims < right_list_ndims { - if let Some(udf) = - self.context_provider.get_function_meta("array_prepend") - { - return Ok(Expr::ScalarFunction(ScalarFunction::new_udf( - udf, - vec![left, right], - ))); - } else { - return internal_err!("array_append not found"); - } + // try extension planers + for planner in self.planners.iter() { + if let Some(expr) = + planner.plan_binary_op(op.clone(), left.clone(), right.clone(), schema)? + { + return Ok(expr); } } + + // by default, convert to datafusion operator + Ok(Expr::BinaryExpr(BinaryExpr::new( Box::new(left), - op, + self.parse_sql_binary_op(op)?, Box::new(right), ))) } @@ -1017,6 +979,71 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } } +pub struct ArrayFunctionPlanner { + array_concat: Arc, + array_append: Arc, + array_prepend: Arc, +} + +impl ArrayFunctionPlanner { + pub fn try_new(context_provider: &dyn ContextProvider) -> Result { + let Some(array_concat) = context_provider.get_function_meta("array_concat") + else { + return internal_err!("array_concat not found"); + }; + let Some(array_append) = context_provider.get_function_meta("array_append") + else { + return internal_err!("array_append not found"); + }; + let Some(array_prepend) = context_provider.get_function_meta("array_prepend") + else { + return internal_err!("array_prepend not found"); + }; + + Ok(Self { + array_concat, + array_append, + array_prepend, + }) + } +} +impl UserDefinedPlanner for ArrayFunctionPlanner { + fn plan_binary_op( + &self, + op: sqlparser::ast::BinaryOperator, + left: Expr, + right: Expr, + schema: &DFSchema, + ) -> Result> { + // Rewrite string concat operator to function based on types + // if we get list || list then we rewrite it to array_concat() + // if we get list || non-list then we rewrite it to array_append() + // if we get non-list || list then we rewrite it to array_prepend() + // if we get string || string then we rewrite it to concat() + if op == sqlparser::ast::BinaryOperator::StringConcat { + let left_type = left.get_type(schema)?; + let right_type = right.get_type(schema)?; + let left_list_ndims = list_ndims(&left_type); + let right_list_ndims = list_ndims(&right_type); + + // We determine the target function to rewrite based on the list n-dimension, the check is not exact but sufficient. + // The exact validity check is handled in the actual function, so even if there is 3d list appended with 1d list, it is also fine to rewrite. + if left_list_ndims + right_list_ndims == 0 { + // TODO: concat function ignore null, but string concat takes null into consideration + // we can rewrite it to concat if we can configure the behaviour of concat function to the one like `string concat operator` + } else if left_list_ndims == right_list_ndims { + return Ok(Some(self.array_concat.call(vec![left, right]))); + } else if left_list_ndims > right_list_ndims { + return Ok(Some(self.array_append.call(vec![left, right]))); + } else if left_list_ndims < right_list_ndims { + return Ok(Some(self.array_prepend.call(vec![left, right]))); + } + } + + Ok(None) + } +} + #[cfg(test)] mod tests { use std::collections::HashMap; diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index 00f221200624..06718032f372 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -31,6 +31,7 @@ use sqlparser::ast::{ArrayElemTypeDef, ExactNumberInfo}; use sqlparser::ast::{ColumnDef as SQLColumnDef, ColumnOption}; use sqlparser::ast::{DataType as SQLDataType, Ident, ObjectName, TableAlias}; +use crate::expr::ArrayFunctionPlanner; use datafusion_common::config::ConfigOptions; use datafusion_common::TableReference; use datafusion_common::{ @@ -236,11 +237,28 @@ impl PlannerContext { } } +/// This trait allows users to customize the behavior of the SQL planner +pub trait UserDefinedPlanner { + /// Plan the binary operation between two expressions, return None if not possible + /// TODO make an API that avoids the need to clone the expressions + fn plan_binary_op( + &self, + _op: sqlparser::ast::BinaryOperator, + _left: Expr, + _right: Expr, + _schema: &DFSchema, + ) -> Result> { + Ok(None) + } +} + /// SQL query planner pub struct SqlToRel<'a, S: ContextProvider> { pub(crate) context_provider: &'a S, pub(crate) options: ParserOptions, pub(crate) normalizer: IdentNormalizer, + /// user defined planner extensions + pub(crate) planners: Vec>, } impl<'a, S: ContextProvider> SqlToRel<'a, S> { @@ -249,14 +267,29 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { Self::new_with_options(context_provider, ParserOptions::default()) } + /// add an user defined planner + pub fn with_user_defined_planner( + mut self, + planner: Arc, + ) -> Self { + self.planners.push(planner); + self + } + /// Create a new query planner pub fn new_with_options(context_provider: &'a S, options: ParserOptions) -> Self { let normalize = options.enable_ident_normalization; + let array_planner = + Arc::new(ArrayFunctionPlanner::try_new(context_provider).unwrap()) as _; + SqlToRel { context_provider, options, normalizer: IdentNormalizer::new(normalize), + planners: vec![], } + // todo put this somewhere else + .with_user_defined_planner(array_planner) } pub fn build_schema(&self, columns: Vec) -> Result { From 3362985a4ea9a84d733cacde4ca5ede233818ae9 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Sun, 30 Jun 2024 09:08:51 +0800 Subject: [PATCH 02/17] at arrow Signed-off-by: jayzhan211 --- datafusion/sql/src/expr/mod.rs | 104 +++++++++++++++++++++++++-------- datafusion/sql/src/planner.rs | 42 +++++++++++-- 2 files changed, 117 insertions(+), 29 deletions(-) diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index 42a742782c9a..7b010bc25946 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -17,6 +17,7 @@ use arrow_schema::DataType; use arrow_schema::TimeUnit; +use datafusion_common::exec_err; use datafusion_common::utils::list_ndims; use sqlparser::ast::{CastKind, Expr as SQLExpr, Subscript, TrimWhereField, Value}; use std::sync::Arc; @@ -32,6 +33,7 @@ use datafusion_expr::{ GetFieldAccess, Like, Literal, Operator, ScalarUDF, TryCast, }; +use crate::planner::PlannerSimplifyResult; use crate::planner::{ContextProvider, PlannerContext, SqlToRel, UserDefinedPlanner}; mod binary_op; @@ -106,21 +108,35 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { schema: &DFSchema, ) -> Result { // try extension planers - for planner in self.planners.iter() { - if let Some(expr) = - planner.plan_binary_op(op.clone(), left.clone(), right.clone(), schema)? - { - return Ok(expr); + let mut binary_expr = crate::planner::BinaryExpr { op, left, right }; + let num_planners = self.planners.len(); + for (i, planner) in self.planners.iter().enumerate() { + match planner.plan_binary_op(binary_expr, schema)? { + PlannerSimplifyResult::Simplified(expr) => { + return Ok(expr); + } + PlannerSimplifyResult::OriginalBinaryExpr(expr) + if i + 1 == num_planners => + { + let crate::planner::BinaryExpr { op, left, right } = expr; + return Ok(Expr::BinaryExpr(BinaryExpr::new( + Box::new(left), + self.parse_sql_binary_op(op)?, + Box::new(right), + ))); + } + PlannerSimplifyResult::OriginalBinaryExpr(expr) => { + binary_expr = expr; + } + _ => { + return exec_err!( + "Unexpected result, do you expect to return OriginalBinaryExpr?" + ) + } } } - // by default, convert to datafusion operator - - Ok(Expr::BinaryExpr(BinaryExpr::new( - Box::new(left), - self.parse_sql_binary_op(op)?, - Box::new(right), - ))) + internal_err!("Unexpect to reach here") } /// Generate a relational expression from a SQL expression @@ -983,6 +999,7 @@ pub struct ArrayFunctionPlanner { array_concat: Arc, array_append: Arc, array_prepend: Arc, + array_has_all: Arc, } impl ArrayFunctionPlanner { @@ -999,48 +1016,85 @@ impl ArrayFunctionPlanner { else { return internal_err!("array_prepend not found"); }; + let Some(array_has_all) = context_provider.get_function_meta("array_has_all") + else { + return internal_err!("array_prepend not found"); + }; Ok(Self { array_concat, array_append, array_prepend, + array_has_all, }) } } impl UserDefinedPlanner for ArrayFunctionPlanner { fn plan_binary_op( &self, - op: sqlparser::ast::BinaryOperator, - left: Expr, - right: Expr, + expr: crate::planner::BinaryExpr, schema: &DFSchema, - ) -> Result> { - // Rewrite string concat operator to function based on types - // if we get list || list then we rewrite it to array_concat() - // if we get list || non-list then we rewrite it to array_append() - // if we get non-list || list then we rewrite it to array_prepend() - // if we get string || string then we rewrite it to concat() + ) -> Result { + let crate::planner::BinaryExpr { op, left, right } = expr; + if op == sqlparser::ast::BinaryOperator::StringConcat { let left_type = left.get_type(schema)?; let right_type = right.get_type(schema)?; let left_list_ndims = list_ndims(&left_type); let right_list_ndims = list_ndims(&right_type); + // Rewrite string concat operator to function based on types + // if we get list || list then we rewrite it to array_concat() + // if we get list || non-list then we rewrite it to array_append() + // if we get non-list || list then we rewrite it to array_prepend() + // if we get string || string then we rewrite it to concat() + // We determine the target function to rewrite based on the list n-dimension, the check is not exact but sufficient. // The exact validity check is handled in the actual function, so even if there is 3d list appended with 1d list, it is also fine to rewrite. if left_list_ndims + right_list_ndims == 0 { // TODO: concat function ignore null, but string concat takes null into consideration // we can rewrite it to concat if we can configure the behaviour of concat function to the one like `string concat operator` } else if left_list_ndims == right_list_ndims { - return Ok(Some(self.array_concat.call(vec![left, right]))); + return Ok(PlannerSimplifyResult::Simplified( + self.array_concat.call(vec![left, right]), + )); } else if left_list_ndims > right_list_ndims { - return Ok(Some(self.array_append.call(vec![left, right]))); + return Ok(PlannerSimplifyResult::Simplified( + self.array_append.call(vec![left, right]), + )); } else if left_list_ndims < right_list_ndims { - return Ok(Some(self.array_prepend.call(vec![left, right]))); + return Ok(PlannerSimplifyResult::Simplified( + self.array_prepend.call(vec![left, right]), + )); + } + } else if matches!( + op, + sqlparser::ast::BinaryOperator::AtArrow + | sqlparser::ast::BinaryOperator::ArrowAt + ) { + let left_type = left.get_type(schema)?; + let right_type = right.get_type(schema)?; + let left_list_ndims = list_ndims(&left_type); + let right_list_ndims = list_ndims(&right_type); + // if both are list + if left_list_ndims > 0 && right_list_ndims > 0 { + if op == sqlparser::ast::BinaryOperator::AtArrow { + // array1 @> array2 -> array_has_all(array1, array2) + return Ok(PlannerSimplifyResult::Simplified( + self.array_has_all.call(vec![left, right]), + )); + } else { + // array1 <@ array2 -> array_has_all(array2, array1) + return Ok(PlannerSimplifyResult::Simplified( + self.array_has_all.call(vec![right, left]), + )); + } } } - Ok(None) + Ok(PlannerSimplifyResult::OriginalBinaryExpr( + crate::planner::BinaryExpr { op, left, right }, + )) } } diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index 06718032f372..4aa2cf025e7a 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -25,7 +25,7 @@ use datafusion_common::file_options::file_type::FileType; use datafusion_common::{ field_not_found, internal_err, plan_datafusion_err, DFSchemaRef, SchemaError, }; -use datafusion_expr::WindowUDF; +use datafusion_expr::{GetFieldAccess, WindowUDF}; use sqlparser::ast::TimezoneInfo; use sqlparser::ast::{ArrayElemTypeDef, ExactNumberInfo}; use sqlparser::ast::{ColumnDef as SQLColumnDef, ColumnOption}; @@ -243,15 +243,49 @@ pub trait UserDefinedPlanner { /// TODO make an API that avoids the need to clone the expressions fn plan_binary_op( &self, - _op: sqlparser::ast::BinaryOperator, - _left: Expr, - _right: Expr, + expr: BinaryExpr, + _schema: &DFSchema, + ) -> Result { + Ok(PlannerSimplifyResult::OriginalBinaryExpr(expr)) + } + + fn plan_field_access( + &self, + _expr: FieldAccessExpr, _schema: &DFSchema, ) -> Result> { Ok(None) } } +pub struct BinaryExpr { + pub op: sqlparser::ast::BinaryOperator, + pub left: Expr, + pub right: Expr, +} + +pub struct FieldAccessExpr { + pub field_access: GetFieldAccess, + pub expr: Expr, +} + +pub enum PlannerSimplifyResult { + /// The function call was simplified to an entirely new Expr + Simplified(Expr), + /// the function call could not be simplified, and the arguments + /// are return unmodified. + OriginalBinaryExpr(BinaryExpr), + OriginalFieldAccessExpr(FieldAccessExpr), +} + +pub enum PlanSimplifyResult { + /// The function call was simplified to an entirely new Expr + Simplified(Expr), + /// the function call could not be simplified, and the arguments + /// are return unmodified. + Original(Vec), +} + /// SQL query planner pub struct SqlToRel<'a, S: ContextProvider> { pub(crate) context_provider: &'a S, From 42f641409a4c66bce99982aacdee6ad559b8f4d6 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Sun, 30 Jun 2024 10:13:03 +0800 Subject: [PATCH 03/17] get field Signed-off-by: jayzhan211 --- datafusion/sql/src/expr/mod.rs | 196 ++++++++++++++++++--------------- datafusion/sql/src/planner.rs | 11 +- 2 files changed, 116 insertions(+), 91 deletions(-) diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index 7b010bc25946..307899681396 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -129,9 +129,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { binary_expr = expr; } _ => { - return exec_err!( - "Unexpected result, do you expect to return OriginalBinaryExpr?" - ) + return exec_err!("Unexpected result encountered. Did you expect an OriginalBinaryExpr?") } } } @@ -222,7 +220,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let expr = self.sql_expr_to_logical_expr(*expr, schema, planner_context)?; - let get_field_access = match *subscript { + let field_access = match *subscript { Subscript::Index { index } => { // index can be a name, in which case it is a named field access match index { @@ -291,7 +289,27 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } }; - self.plan_field_access(expr, get_field_access) + let mut field_access_expr = + crate::planner::FieldAccessExpr { expr, field_access }; + let num_planners = self.planners.len(); + for (i, planner) in self.planners.iter().enumerate() { + match planner.plan_field_access(field_access_expr, schema)? { + PlannerSimplifyResult::Simplified(expr) => { + return Ok(expr) + } + PlannerSimplifyResult::OriginalFieldAccessExpr(_) if i + 1 == num_planners => { + return internal_err!("Expected a simplified result, but none was found") + } + PlannerSimplifyResult::OriginalFieldAccessExpr(expr) => { + field_access_expr = expr; + } + _ => { + return exec_err!("Unexpected result encountered. Did you expect an OriginalFieldAccessExpr?") + } + } + } + + internal_err!("Unexpect to reach here") } SQLExpr::CompoundIdentifier(ids) => { @@ -626,36 +644,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } } - /// Simplifies an expression like `ARRAY_AGG(expr)[index]` to `NTH_VALUE(expr, index)` - /// - /// returns Some(Expr) if the expression was simplified, otherwise None - /// TODO: this should likely be done in ArrayAgg::simplify when it is moved to a UDAF - fn simplify_array_index_expr(expr: &Expr, index: &Expr) -> Option { - fn is_array_agg(agg_func: &datafusion_expr::expr::AggregateFunction) -> bool { - agg_func.func_def - == datafusion_expr::expr::AggregateFunctionDefinition::BuiltIn( - AggregateFunction::ArrayAgg, - ) - } - match expr { - Expr::AggregateFunction(agg_func) if is_array_agg(agg_func) => { - let mut new_args = agg_func.args.clone(); - new_args.push(index.clone()); - Some(Expr::AggregateFunction( - datafusion_expr::expr::AggregateFunction::new( - AggregateFunction::NthValue, - new_args, - agg_func.distinct, - agg_func.filter.clone(), - agg_func.order_by.clone(), - agg_func.null_treatment, - ), - )) - } - _ => None, - } - } - /// Parses a struct(..) expression fn parse_struct( &self, @@ -941,58 +929,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let args = vec![fullstr, substr]; Ok(Expr::ScalarFunction(ScalarFunction::new_udf(fun, args))) } - - /// Given an expression and the field to access, creates a new expression for accessing that field - fn plan_field_access( - &self, - expr: Expr, - get_field_access: GetFieldAccess, - ) -> Result { - match get_field_access { - GetFieldAccess::NamedStructField { name } => { - if let Some(udf) = self.context_provider.get_function_meta("get_field") { - Ok(Expr::ScalarFunction(ScalarFunction::new_udf( - udf, - vec![expr, lit(name)], - ))) - } else { - internal_err!("get_field not found") - } - } - // expr[idx] ==> array_element(expr, idx) - GetFieldAccess::ListIndex { key } => { - // Special case for array_agg(expr)[index] to NTH_VALUE(expr, index) - if let Some(simplified) = Self::simplify_array_index_expr(&expr, &key) { - Ok(simplified) - } else if let Some(udf) = - self.context_provider.get_function_meta("array_element") - { - Ok(Expr::ScalarFunction(ScalarFunction::new_udf( - udf, - vec![expr, *key], - ))) - } else { - internal_err!("get_field not found") - } - } - // expr[start, stop, stride] ==> array_slice(expr, start, stop, stride) - GetFieldAccess::ListRange { - start, - stop, - stride, - } => { - if let Some(udf) = self.context_provider.get_function_meta("array_slice") - { - Ok(Expr::ScalarFunction(ScalarFunction::new_udf( - udf, - vec![expr, *start, *stop, *stride], - ))) - } else { - internal_err!("array_slice not found") - } - } - } - } } pub struct ArrayFunctionPlanner { @@ -1098,6 +1034,92 @@ impl UserDefinedPlanner for ArrayFunctionPlanner { } } +pub struct FieldAccessPlanner { + get_field: Arc, + array_element: Arc, + array_slice: Arc, +} + +impl FieldAccessPlanner { + pub fn try_new(context_provider: &dyn ContextProvider) -> Result { + let Some(get_field) = context_provider.get_function_meta("get_field") else { + return internal_err!("get_feild not found"); + }; + let Some(array_element) = context_provider.get_function_meta("array_element") + else { + return internal_err!("array_element not found"); + }; + let Some(array_slice) = context_provider.get_function_meta("array_slice") else { + return internal_err!("array_slice not found"); + }; + + Ok(Self { + get_field, + array_element, + array_slice, + }) + } +} + +impl UserDefinedPlanner for FieldAccessPlanner { + fn plan_field_access( + &self, + expr: crate::planner::FieldAccessExpr, + _schema: &DFSchema, + ) -> Result { + let crate::planner::FieldAccessExpr { expr, field_access } = expr; + + match field_access { + // expr["field"] => get_field(expr, "field") + GetFieldAccess::NamedStructField { name } => { + Ok(PlannerSimplifyResult::Simplified( + self.get_field.call(vec![expr, lit(name)]), + )) + } + // expr[idx] ==> array_element(expr, idx) + GetFieldAccess::ListIndex { key: index } => { + match expr { + // Special case for array_agg(expr)[index] to NTH_VALUE(expr, index) + Expr::AggregateFunction(agg_func) if is_array_agg(&agg_func) => { + Ok(PlannerSimplifyResult::Simplified(Expr::AggregateFunction( + datafusion_expr::expr::AggregateFunction::new( + AggregateFunction::NthValue, + agg_func + .args + .into_iter() + .chain(std::iter::once(*index)) + .collect(), + agg_func.distinct, + agg_func.filter.clone(), + agg_func.order_by.clone(), + agg_func.null_treatment, + ), + ))) + } + _ => Ok(PlannerSimplifyResult::Simplified( + self.array_element.call(vec![expr, *index]), + )), + } + } + // expr[start, stop, stride] ==> array_slice(expr, start, stop, stride) + GetFieldAccess::ListRange { + start, + stop, + stride, + } => Ok(PlannerSimplifyResult::Simplified( + self.array_slice.call(vec![expr, *start, *stop, *stride]), + )), + } + } +} + +fn is_array_agg(agg_func: &datafusion_expr::expr::AggregateFunction) -> bool { + agg_func.func_def + == datafusion_expr::expr::AggregateFunctionDefinition::BuiltIn( + AggregateFunction::ArrayAgg, + ) +} + #[cfg(test)] mod tests { use std::collections::HashMap; diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index 4aa2cf025e7a..deb8afa19144 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -31,7 +31,7 @@ use sqlparser::ast::{ArrayElemTypeDef, ExactNumberInfo}; use sqlparser::ast::{ColumnDef as SQLColumnDef, ColumnOption}; use sqlparser::ast::{DataType as SQLDataType, Ident, ObjectName, TableAlias}; -use crate::expr::ArrayFunctionPlanner; +use crate::expr::{ArrayFunctionPlanner, FieldAccessPlanner}; use datafusion_common::config::ConfigOptions; use datafusion_common::TableReference; use datafusion_common::{ @@ -251,10 +251,10 @@ pub trait UserDefinedPlanner { fn plan_field_access( &self, - _expr: FieldAccessExpr, + expr: FieldAccessExpr, _schema: &DFSchema, - ) -> Result> { - Ok(None) + ) -> Result { + Ok(PlannerSimplifyResult::OriginalFieldAccessExpr(expr)) } } @@ -315,6 +315,8 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let normalize = options.enable_ident_normalization; let array_planner = Arc::new(ArrayFunctionPlanner::try_new(context_provider).unwrap()) as _; + let field_access_planner = + Arc::new(FieldAccessPlanner::try_new(context_provider).unwrap()) as _; SqlToRel { context_provider, @@ -324,6 +326,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } // todo put this somewhere else .with_user_defined_planner(array_planner) + .with_user_defined_planner(field_access_planner) } pub fn build_schema(&self, columns: Vec) -> Result { From 8ece394080eadc7c23bcbd11a17074d1544d5f6f Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Sun, 30 Jun 2024 10:19:43 +0800 Subject: [PATCH 04/17] cleanup Signed-off-by: jayzhan211 --- datafusion/sql/src/expr/mod.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index 307899681396..0928d33c4511 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -954,7 +954,7 @@ impl ArrayFunctionPlanner { }; let Some(array_has_all) = context_provider.get_function_meta("array_has_all") else { - return internal_err!("array_prepend not found"); + return internal_err!("array_has_all not found"); }; Ok(Self { @@ -1090,8 +1090,8 @@ impl UserDefinedPlanner for FieldAccessPlanner { .chain(std::iter::once(*index)) .collect(), agg_func.distinct, - agg_func.filter.clone(), - agg_func.order_by.clone(), + agg_func.filter, + agg_func.order_by, agg_func.null_treatment, ), ))) From d5d31893b08f2ffcb5fc51bba08d1d17131e0aac Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Sun, 30 Jun 2024 10:47:16 +0800 Subject: [PATCH 05/17] plan array literal Signed-off-by: jayzhan211 --- datafusion/sql/src/expr/mod.rs | 15 ++++++++++++++ datafusion/sql/src/expr/value.rs | 34 +++++++++++++++++++++++--------- datafusion/sql/src/planner.rs | 11 ++++++++++- 3 files changed, 50 insertions(+), 10 deletions(-) diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index 0928d33c4511..62582bf6b4b3 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -936,6 +936,7 @@ pub struct ArrayFunctionPlanner { array_append: Arc, array_prepend: Arc, array_has_all: Arc, + make_array: Arc, } impl ArrayFunctionPlanner { @@ -956,12 +957,16 @@ impl ArrayFunctionPlanner { else { return internal_err!("array_has_all not found"); }; + let Some(make_array) = context_provider.get_function_meta("make_array") else { + return internal_err!("make_array not found"); + }; Ok(Self { array_concat, array_append, array_prepend, array_has_all, + make_array, }) } } @@ -1032,6 +1037,16 @@ impl UserDefinedPlanner for ArrayFunctionPlanner { crate::planner::BinaryExpr { op, left, right }, )) } + + fn plan_array_literal( + &self, + exprs: Vec, + _schema: &DFSchema, + ) -> Result { + Ok(PlannerSimplifyResult::Simplified( + self.make_array.call(exprs), + )) + } } pub struct FieldAccessPlanner { diff --git a/datafusion/sql/src/expr/value.rs b/datafusion/sql/src/expr/value.rs index fa95fc2e051d..225e4736c621 100644 --- a/datafusion/sql/src/expr/value.rs +++ b/datafusion/sql/src/expr/value.rs @@ -15,14 +15,15 @@ // specific language governing permissions and limitations // under the License. -use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; +use crate::planner::{ContextProvider, PlannerContext, PlannerSimplifyResult, SqlToRel}; use arrow::compute::kernels::cast_utils::parse_interval_month_day_nano; use arrow::datatypes::DECIMAL128_MAX_PRECISION; use arrow_schema::DataType; use datafusion_common::{ - not_impl_err, plan_err, DFSchema, DataFusionError, Result, ScalarValue, + exec_err, internal_err, not_impl_err, plan_err, DFSchema, DataFusionError, Result, + ScalarValue, }; -use datafusion_expr::expr::{BinaryExpr, Placeholder, ScalarFunction}; +use datafusion_expr::expr::{BinaryExpr, Placeholder}; use datafusion_expr::{lit, Expr, Operator}; use log::debug; use sqlparser::ast::{BinaryOperator, Expr as SQLExpr, Interval, Value}; @@ -142,13 +143,28 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { }) .collect::>>()?; - if let Some(udf) = self.context_provider.get_function_meta("make_array") { - Ok(Expr::ScalarFunction(ScalarFunction::new_udf(udf, values))) - } else { - not_impl_err!( - "array_expression featrue is disable, So should implement make_array UDF by yourself" - ) + let mut exprs = values; + let num_planners = self.planners.len(); + for (i, planner) in self.planners.iter().enumerate() { + match planner.plan_array_literal(exprs, schema)? { + PlannerSimplifyResult::Simplified(expr) => { + return Ok(expr); + } + PlannerSimplifyResult::OriginalArray(_) if i + 1 == num_planners => { + return internal_err!( + "Expected a simplified result, but none was found" + ) + } + PlannerSimplifyResult::OriginalArray(values) => exprs = values, + _ => { + return exec_err!( + "Unexpected result encountered. Did you expect an OriginalArray?" + ) + } + } } + + internal_err!("Unexpect to reach here") } /// Convert a SQL interval expression to a DataFusion logical plan diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index deb8afa19144..7fbe01b664ed 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -240,7 +240,6 @@ impl PlannerContext { /// This trait allows users to customize the behavior of the SQL planner pub trait UserDefinedPlanner { /// Plan the binary operation between two expressions, return None if not possible - /// TODO make an API that avoids the need to clone the expressions fn plan_binary_op( &self, expr: BinaryExpr, @@ -249,6 +248,7 @@ pub trait UserDefinedPlanner { Ok(PlannerSimplifyResult::OriginalBinaryExpr(expr)) } + /// Plan the field access expression, return None if not possible fn plan_field_access( &self, expr: FieldAccessExpr, @@ -256,6 +256,14 @@ pub trait UserDefinedPlanner { ) -> Result { Ok(PlannerSimplifyResult::OriginalFieldAccessExpr(expr)) } + + fn plan_array_literal( + &self, + exprs: Vec, + _schema: &DFSchema, + ) -> Result { + Ok(PlannerSimplifyResult::OriginalArray(exprs)) + } } pub struct BinaryExpr { @@ -276,6 +284,7 @@ pub enum PlannerSimplifyResult { /// are return unmodified. OriginalBinaryExpr(BinaryExpr), OriginalFieldAccessExpr(FieldAccessExpr), + OriginalArray(Vec), } pub enum PlanSimplifyResult { From cdf5342226f093589938c90fa712fbc19f23b32f Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Sun, 30 Jun 2024 13:44:51 +0800 Subject: [PATCH 06/17] move to functions-array Signed-off-by: jayzhan211 --- datafusion-cli/Cargo.lock | 2 + .../core/src/execution/session_state.rs | 31 ++- datafusion/expr/src/lib.rs | 1 + datafusion/expr/src/planner.rs | 113 ++++++++ datafusion/functions-array/Cargo.toml | 2 + datafusion/functions-array/src/lib.rs | 1 + datafusion/functions-array/src/planner.rs | 235 +++++++++++++++++ datafusion/sql/src/expr/mod.rs | 248 ++---------------- datafusion/sql/src/expr/value.rs | 13 +- datafusion/sql/src/planner.rs | 175 ++++-------- 10 files changed, 454 insertions(+), 367 deletions(-) create mode 100644 datafusion/expr/src/planner.rs create mode 100644 datafusion/functions-array/src/planner.rs diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index 28312fee79a7..5292a0dbab1b 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -1319,9 +1319,11 @@ dependencies = [ "datafusion-execution", "datafusion-expr", "datafusion-functions", + "datafusion-sql", "itertools", "log", "paste", + "sqlparser", ] [[package]] diff --git a/datafusion/core/src/execution/session_state.rs b/datafusion/core/src/execution/session_state.rs index 0b880ddbf81b..752762010410 100644 --- a/datafusion/core/src/execution/session_state.rs +++ b/datafusion/core/src/execution/session_state.rs @@ -67,6 +67,7 @@ use datafusion_expr::{ AggregateUDF, Explain, Expr, ExprSchemable, LogicalPlan, ScalarUDF, TableSource, WindowUDF, }; +use datafusion_functions_array::planner::{ArrayFunctionPlanner, FieldAccessPlanner}; use datafusion_optimizer::simplify_expressions::ExprSimplifier; use datafusion_optimizer::{ Analyzer, AnalyzerRule, Optimizer, OptimizerConfig, OptimizerRule, @@ -605,7 +606,7 @@ impl SessionState { } } - let query = SqlToRel::new_with_options(&provider, self.get_parser_options()); + let query = self.build_sql_query_planner(&provider); query.statement_to_plan(statement) } @@ -658,8 +659,7 @@ impl SessionState { tables: HashMap::new(), }; - let query = SqlToRel::new_with_options(&provider, self.get_parser_options()); - + let query = self.build_sql_query_planner(&provider); query.sql_to_expr(sql_expr, df_schema, &mut PlannerContext::new()) } @@ -943,6 +943,31 @@ impl SessionState { let udtf = self.table_functions.remove(name); Ok(udtf.map(|x| x.function().clone())) } + + fn build_sql_query_planner<'a, S>(&self, provider: &'a S) -> SqlToRel<'a, S> + where + S: ContextProvider, + { + let query = SqlToRel::new_with_options(provider, self.get_parser_options()); + + // register crate of array expressions (if enabled) + #[cfg(feature = "array_expressions")] + { + let array_planner = + Arc::new(ArrayFunctionPlanner::try_new(provider).unwrap()) as _; + + let field_access_planner = + Arc::new(FieldAccessPlanner::try_new(provider).unwrap()) as _; + + query + .with_user_defined_planner(array_planner) + .with_user_defined_planner(field_access_planner) + } + #[cfg(not(feature = "array_expressions"))] + { + query + } + } } struct SessionContextProvider<'a> { diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index 89ee94f9f845..38f0617fc5fa 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -48,6 +48,7 @@ pub mod function; pub mod groups_accumulator; pub mod interval_arithmetic; pub mod logical_plan; +pub mod planner; pub mod registry; pub mod simplify; pub mod sort_properties; diff --git a/datafusion/expr/src/planner.rs b/datafusion/expr/src/planner.rs new file mode 100644 index 000000000000..df0e156dee52 --- /dev/null +++ b/datafusion/expr/src/planner.rs @@ -0,0 +1,113 @@ +use std::sync::Arc; + +use arrow::datatypes::{DataType, SchemaRef}; +use datafusion_common::{ + config::ConfigOptions, file_options::file_type::FileType, not_impl_err, DFSchema, + Result, TableReference, +}; + +use crate::{AggregateUDF, Expr, GetFieldAccess, ScalarUDF, TableSource, WindowUDF}; + +/// The ContextProvider trait allows the query planner to obtain meta-data about tables and +/// functions referenced in SQL statements +pub trait ContextProvider { + /// Getter for a datasource + fn get_table_source(&self, name: TableReference) -> Result>; + + fn get_file_type(&self, _ext: &str) -> Result> { + not_impl_err!("Registered file types are not supported") + } + + /// Getter for a table function + fn get_table_function_source( + &self, + _name: &str, + _args: Vec, + ) -> Result> { + not_impl_err!("Table Functions are not supported") + } + + /// This provides a worktable (an intermediate table that is used to store the results of a CTE during execution) + /// We don't directly implement this in the logical plan's ['SqlToRel`] + /// because the sql code needs access to a table that contains execution-related types that can't be a direct dependency + /// of the sql crate (namely, the `CteWorktable`). + /// The [`ContextProvider`] provides a way to "hide" this dependency. + fn create_cte_work_table( + &self, + _name: &str, + _schema: SchemaRef, + ) -> Result> { + not_impl_err!("Recursive CTE is not implemented") + } + + /// Getter for a UDF description + fn get_function_meta(&self, name: &str) -> Option>; + /// Getter for a UDAF description + fn get_aggregate_meta(&self, name: &str) -> Option>; + /// Getter for a UDWF + fn get_window_meta(&self, name: &str) -> Option>; + /// Getter for system/user-defined variable type + fn get_variable_type(&self, variable_names: &[String]) -> Option; + + /// Get configuration options + fn options(&self) -> &ConfigOptions; + + /// Get all user defined scalar function names + fn udf_names(&self) -> Vec; + + /// Get all user defined aggregate function names + fn udaf_names(&self) -> Vec; + + /// Get all user defined window function names + fn udwf_names(&self) -> Vec; +} + +/// This trait allows users to customize the behavior of the SQL planner +pub trait UserDefinedPlanner { + /// Plan the binary operation between two expressions, return None if not possible + fn plan_binary_op( + &self, + expr: BinaryExpr, + _schema: &DFSchema, + ) -> Result { + Ok(PlannerSimplifyResult::OriginalBinaryExpr(expr)) + } + + /// Plan the field access expression, return None if not possible + fn plan_field_access( + &self, + expr: FieldAccessExpr, + _schema: &DFSchema, + ) -> Result { + Ok(PlannerSimplifyResult::OriginalFieldAccessExpr(expr)) + } + + fn plan_array_literal( + &self, + exprs: Vec, + _schema: &DFSchema, + ) -> Result { + Ok(PlannerSimplifyResult::OriginalArray(exprs)) + } +} + +pub struct BinaryExpr { + pub op: sqlparser::ast::BinaryOperator, + pub left: Expr, + pub right: Expr, +} + +pub struct FieldAccessExpr { + pub field_access: GetFieldAccess, + pub expr: Expr, +} + +pub enum PlannerSimplifyResult { + /// The function call was simplified to an entirely new Expr + Simplified(Expr), + /// the function call could not be simplified, and the arguments + /// are return unmodified. + OriginalBinaryExpr(BinaryExpr), + OriginalFieldAccessExpr(FieldAccessExpr), + OriginalArray(Vec), +} diff --git a/datafusion/functions-array/Cargo.toml b/datafusion/functions-array/Cargo.toml index eb1ef9e03f31..784677e66462 100644 --- a/datafusion/functions-array/Cargo.toml +++ b/datafusion/functions-array/Cargo.toml @@ -49,9 +49,11 @@ datafusion-common = { workspace = true } datafusion-execution = { workspace = true } datafusion-expr = { workspace = true } datafusion-functions = { workspace = true } +datafusion-sql = { workspace = true } itertools = { version = "0.12", features = ["use_std"] } log = { workspace = true } paste = "1.0.14" +sqlparser = { workspace = true } [dev-dependencies] criterion = { version = "0.5", features = ["async_tokio"] } diff --git a/datafusion/functions-array/src/lib.rs b/datafusion/functions-array/src/lib.rs index 2c46d685abbe..814127be806b 100644 --- a/datafusion/functions-array/src/lib.rs +++ b/datafusion/functions-array/src/lib.rs @@ -39,6 +39,7 @@ pub mod extract; pub mod flatten; pub mod length; pub mod make_array; +pub mod planner; pub mod position; pub mod range; pub mod remove; diff --git a/datafusion/functions-array/src/planner.rs b/datafusion/functions-array/src/planner.rs new file mode 100644 index 000000000000..8dbfaba87436 --- /dev/null +++ b/datafusion/functions-array/src/planner.rs @@ -0,0 +1,235 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use datafusion_common::{internal_err, utils::list_ndims, DFSchema, Result}; +use datafusion_expr::{ + lit, + planner::{ + BinaryExpr, ContextProvider, FieldAccessExpr, PlannerSimplifyResult, + UserDefinedPlanner, + }, + AggregateFunction, Expr, ExprSchemable, GetFieldAccess, ScalarUDF, +}; + +pub struct ArrayFunctionPlanner { + array_concat: Arc, + array_append: Arc, + array_prepend: Arc, + array_has_all: Arc, + make_array: Arc, +} + +impl ArrayFunctionPlanner { + pub fn try_new(context_provider: &dyn ContextProvider) -> Result { + let Some(array_concat) = context_provider.get_function_meta("array_concat") + else { + return internal_err!("array_concat not found"); + }; + let Some(array_append) = context_provider.get_function_meta("array_append") + else { + return internal_err!("array_append not found"); + }; + let Some(array_prepend) = context_provider.get_function_meta("array_prepend") + else { + return internal_err!("array_prepend not found"); + }; + let Some(array_has_all) = context_provider.get_function_meta("array_has_all") + else { + return internal_err!("array_has_all not found"); + }; + let Some(make_array) = context_provider.get_function_meta("make_array") else { + return internal_err!("make_array not found"); + }; + + Ok(Self { + array_concat, + array_append, + array_prepend, + array_has_all, + make_array, + }) + } +} + +impl UserDefinedPlanner for ArrayFunctionPlanner { + fn plan_binary_op( + &self, + expr: BinaryExpr, + schema: &DFSchema, + ) -> Result { + let BinaryExpr { op, left, right } = expr; + + if op == sqlparser::ast::BinaryOperator::StringConcat { + let left_type = left.get_type(schema)?; + let right_type = right.get_type(schema)?; + let left_list_ndims = list_ndims(&left_type); + let right_list_ndims = list_ndims(&right_type); + + // Rewrite string concat operator to function based on types + // if we get list || list then we rewrite it to array_concat() + // if we get list || non-list then we rewrite it to array_append() + // if we get non-list || list then we rewrite it to array_prepend() + // if we get string || string then we rewrite it to concat() + + // We determine the target function to rewrite based on the list n-dimension, the check is not exact but sufficient. + // The exact validity check is handled in the actual function, so even if there is 3d list appended with 1d list, it is also fine to rewrite. + if left_list_ndims + right_list_ndims == 0 { + // TODO: concat function ignore null, but string concat takes null into consideration + // we can rewrite it to concat if we can configure the behaviour of concat function to the one like `string concat operator` + } else if left_list_ndims == right_list_ndims { + return Ok(PlannerSimplifyResult::Simplified( + self.array_concat.call(vec![left, right]), + )); + } else if left_list_ndims > right_list_ndims { + return Ok(PlannerSimplifyResult::Simplified( + self.array_append.call(vec![left, right]), + )); + } else if left_list_ndims < right_list_ndims { + return Ok(PlannerSimplifyResult::Simplified( + self.array_prepend.call(vec![left, right]), + )); + } + } else if matches!( + op, + sqlparser::ast::BinaryOperator::AtArrow + | sqlparser::ast::BinaryOperator::ArrowAt + ) { + let left_type = left.get_type(schema)?; + let right_type = right.get_type(schema)?; + let left_list_ndims = list_ndims(&left_type); + let right_list_ndims = list_ndims(&right_type); + // if both are list + if left_list_ndims > 0 && right_list_ndims > 0 { + if op == sqlparser::ast::BinaryOperator::AtArrow { + // array1 @> array2 -> array_has_all(array1, array2) + return Ok(PlannerSimplifyResult::Simplified( + self.array_has_all.call(vec![left, right]), + )); + } else { + // array1 <@ array2 -> array_has_all(array2, array1) + return Ok(PlannerSimplifyResult::Simplified( + self.array_has_all.call(vec![right, left]), + )); + } + } + } + + Ok(PlannerSimplifyResult::OriginalBinaryExpr(BinaryExpr { + op, + left, + right, + })) + } + + fn plan_array_literal( + &self, + exprs: Vec, + _schema: &DFSchema, + ) -> Result { + Ok(PlannerSimplifyResult::Simplified( + self.make_array.call(exprs), + )) + } +} + +pub struct FieldAccessPlanner { + get_field: Arc, + array_element: Arc, + array_slice: Arc, +} + +impl FieldAccessPlanner { + pub fn try_new(context_provider: &dyn ContextProvider) -> Result { + let Some(get_field) = context_provider.get_function_meta("get_field") else { + return internal_err!("get_feild not found"); + }; + let Some(array_element) = context_provider.get_function_meta("array_element") + else { + return internal_err!("array_element not found"); + }; + let Some(array_slice) = context_provider.get_function_meta("array_slice") else { + return internal_err!("array_slice not found"); + }; + + Ok(Self { + get_field, + array_element, + array_slice, + }) + } +} + +impl UserDefinedPlanner for FieldAccessPlanner { + fn plan_field_access( + &self, + expr: FieldAccessExpr, + _schema: &DFSchema, + ) -> Result { + let FieldAccessExpr { expr, field_access } = expr; + + match field_access { + // expr["field"] => get_field(expr, "field") + GetFieldAccess::NamedStructField { name } => { + Ok(PlannerSimplifyResult::Simplified( + self.get_field.call(vec![expr, lit(name)]), + )) + } + // expr[idx] ==> array_element(expr, idx) + GetFieldAccess::ListIndex { key: index } => { + match expr { + // Special case for array_agg(expr)[index] to NTH_VALUE(expr, index) + Expr::AggregateFunction(agg_func) if is_array_agg(&agg_func) => { + Ok(PlannerSimplifyResult::Simplified(Expr::AggregateFunction( + datafusion_expr::expr::AggregateFunction::new( + AggregateFunction::NthValue, + agg_func + .args + .into_iter() + .chain(std::iter::once(*index)) + .collect(), + agg_func.distinct, + agg_func.filter, + agg_func.order_by, + agg_func.null_treatment, + ), + ))) + } + _ => Ok(PlannerSimplifyResult::Simplified( + self.array_element.call(vec![expr, *index]), + )), + } + } + // expr[start, stop, stride] ==> array_slice(expr, start, stop, stride) + GetFieldAccess::ListRange { + start, + stop, + stride, + } => Ok(PlannerSimplifyResult::Simplified( + self.array_slice.call(vec![expr, *start, *stop, *stride]), + )), + } + } +} + +fn is_array_agg(agg_func: &datafusion_expr::expr::AggregateFunction) -> bool { + agg_func.func_def + == datafusion_expr::expr::AggregateFunctionDefinition::BuiltIn( + AggregateFunction::ArrayAgg, + ) +} diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index 62582bf6b4b3..0f68ee44f6ab 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -18,9 +18,9 @@ use arrow_schema::DataType; use arrow_schema::TimeUnit; use datafusion_common::exec_err; -use datafusion_common::utils::list_ndims; +use datafusion_expr::planner::FieldAccessExpr; +use datafusion_expr::planner::PlannerSimplifyResult; use sqlparser::ast::{CastKind, Expr as SQLExpr, Subscript, TrimWhereField, Value}; -use std::sync::Arc; use datafusion_common::{ internal_datafusion_err, internal_err, not_impl_err, plan_err, DFSchema, Result, @@ -29,12 +29,11 @@ use datafusion_common::{ use datafusion_expr::expr::InList; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::{ - lit, AggregateFunction, Between, BinaryExpr, Cast, Expr, ExprSchemable, - GetFieldAccess, Like, Literal, Operator, ScalarUDF, TryCast, + lit, Between, BinaryExpr, Cast, Expr, ExprSchemable, GetFieldAccess, Like, Literal, + Operator, TryCast, }; -use crate::planner::PlannerSimplifyResult; -use crate::planner::{ContextProvider, PlannerContext, SqlToRel, UserDefinedPlanner}; +use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; mod binary_op; mod function; @@ -108,23 +107,12 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { schema: &DFSchema, ) -> Result { // try extension planers - let mut binary_expr = crate::planner::BinaryExpr { op, left, right }; - let num_planners = self.planners.len(); - for (i, planner) in self.planners.iter().enumerate() { + let mut binary_expr = datafusion_expr::planner::BinaryExpr { op, left, right }; + for planner in self.planners.iter() { match planner.plan_binary_op(binary_expr, schema)? { PlannerSimplifyResult::Simplified(expr) => { return Ok(expr); } - PlannerSimplifyResult::OriginalBinaryExpr(expr) - if i + 1 == num_planners => - { - let crate::planner::BinaryExpr { op, left, right } = expr; - return Ok(Expr::BinaryExpr(BinaryExpr::new( - Box::new(left), - self.parse_sql_binary_op(op)?, - Box::new(right), - ))); - } PlannerSimplifyResult::OriginalBinaryExpr(expr) => { binary_expr = expr; } @@ -134,7 +122,12 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } } - internal_err!("Unexpect to reach here") + let datafusion_expr::planner::BinaryExpr { op, left, right } = binary_expr; + Ok(Expr::BinaryExpr(BinaryExpr::new( + Box::new(left), + self.parse_sql_binary_op(op)?, + Box::new(right), + ))) } /// Generate a relational expression from a SQL expression @@ -289,17 +282,12 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } }; - let mut field_access_expr = - crate::planner::FieldAccessExpr { expr, field_access }; - let num_planners = self.planners.len(); - for (i, planner) in self.planners.iter().enumerate() { + let mut field_access_expr = FieldAccessExpr { expr, field_access }; + for planner in self.planners.iter() { match planner.plan_field_access(field_access_expr, schema)? { PlannerSimplifyResult::Simplified(expr) => { return Ok(expr) } - PlannerSimplifyResult::OriginalFieldAccessExpr(_) if i + 1 == num_planners => { - return internal_err!("Expected a simplified result, but none was found") - } PlannerSimplifyResult::OriginalFieldAccessExpr(expr) => { field_access_expr = expr; } @@ -309,7 +297,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } } - internal_err!("Unexpect to reach here") + internal_err!("Expected a simplified result, but none was found") } SQLExpr::CompoundIdentifier(ids) => { @@ -931,210 +919,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } } -pub struct ArrayFunctionPlanner { - array_concat: Arc, - array_append: Arc, - array_prepend: Arc, - array_has_all: Arc, - make_array: Arc, -} - -impl ArrayFunctionPlanner { - pub fn try_new(context_provider: &dyn ContextProvider) -> Result { - let Some(array_concat) = context_provider.get_function_meta("array_concat") - else { - return internal_err!("array_concat not found"); - }; - let Some(array_append) = context_provider.get_function_meta("array_append") - else { - return internal_err!("array_append not found"); - }; - let Some(array_prepend) = context_provider.get_function_meta("array_prepend") - else { - return internal_err!("array_prepend not found"); - }; - let Some(array_has_all) = context_provider.get_function_meta("array_has_all") - else { - return internal_err!("array_has_all not found"); - }; - let Some(make_array) = context_provider.get_function_meta("make_array") else { - return internal_err!("make_array not found"); - }; - - Ok(Self { - array_concat, - array_append, - array_prepend, - array_has_all, - make_array, - }) - } -} -impl UserDefinedPlanner for ArrayFunctionPlanner { - fn plan_binary_op( - &self, - expr: crate::planner::BinaryExpr, - schema: &DFSchema, - ) -> Result { - let crate::planner::BinaryExpr { op, left, right } = expr; - - if op == sqlparser::ast::BinaryOperator::StringConcat { - let left_type = left.get_type(schema)?; - let right_type = right.get_type(schema)?; - let left_list_ndims = list_ndims(&left_type); - let right_list_ndims = list_ndims(&right_type); - - // Rewrite string concat operator to function based on types - // if we get list || list then we rewrite it to array_concat() - // if we get list || non-list then we rewrite it to array_append() - // if we get non-list || list then we rewrite it to array_prepend() - // if we get string || string then we rewrite it to concat() - - // We determine the target function to rewrite based on the list n-dimension, the check is not exact but sufficient. - // The exact validity check is handled in the actual function, so even if there is 3d list appended with 1d list, it is also fine to rewrite. - if left_list_ndims + right_list_ndims == 0 { - // TODO: concat function ignore null, but string concat takes null into consideration - // we can rewrite it to concat if we can configure the behaviour of concat function to the one like `string concat operator` - } else if left_list_ndims == right_list_ndims { - return Ok(PlannerSimplifyResult::Simplified( - self.array_concat.call(vec![left, right]), - )); - } else if left_list_ndims > right_list_ndims { - return Ok(PlannerSimplifyResult::Simplified( - self.array_append.call(vec![left, right]), - )); - } else if left_list_ndims < right_list_ndims { - return Ok(PlannerSimplifyResult::Simplified( - self.array_prepend.call(vec![left, right]), - )); - } - } else if matches!( - op, - sqlparser::ast::BinaryOperator::AtArrow - | sqlparser::ast::BinaryOperator::ArrowAt - ) { - let left_type = left.get_type(schema)?; - let right_type = right.get_type(schema)?; - let left_list_ndims = list_ndims(&left_type); - let right_list_ndims = list_ndims(&right_type); - // if both are list - if left_list_ndims > 0 && right_list_ndims > 0 { - if op == sqlparser::ast::BinaryOperator::AtArrow { - // array1 @> array2 -> array_has_all(array1, array2) - return Ok(PlannerSimplifyResult::Simplified( - self.array_has_all.call(vec![left, right]), - )); - } else { - // array1 <@ array2 -> array_has_all(array2, array1) - return Ok(PlannerSimplifyResult::Simplified( - self.array_has_all.call(vec![right, left]), - )); - } - } - } - - Ok(PlannerSimplifyResult::OriginalBinaryExpr( - crate::planner::BinaryExpr { op, left, right }, - )) - } - - fn plan_array_literal( - &self, - exprs: Vec, - _schema: &DFSchema, - ) -> Result { - Ok(PlannerSimplifyResult::Simplified( - self.make_array.call(exprs), - )) - } -} - -pub struct FieldAccessPlanner { - get_field: Arc, - array_element: Arc, - array_slice: Arc, -} - -impl FieldAccessPlanner { - pub fn try_new(context_provider: &dyn ContextProvider) -> Result { - let Some(get_field) = context_provider.get_function_meta("get_field") else { - return internal_err!("get_feild not found"); - }; - let Some(array_element) = context_provider.get_function_meta("array_element") - else { - return internal_err!("array_element not found"); - }; - let Some(array_slice) = context_provider.get_function_meta("array_slice") else { - return internal_err!("array_slice not found"); - }; - - Ok(Self { - get_field, - array_element, - array_slice, - }) - } -} - -impl UserDefinedPlanner for FieldAccessPlanner { - fn plan_field_access( - &self, - expr: crate::planner::FieldAccessExpr, - _schema: &DFSchema, - ) -> Result { - let crate::planner::FieldAccessExpr { expr, field_access } = expr; - - match field_access { - // expr["field"] => get_field(expr, "field") - GetFieldAccess::NamedStructField { name } => { - Ok(PlannerSimplifyResult::Simplified( - self.get_field.call(vec![expr, lit(name)]), - )) - } - // expr[idx] ==> array_element(expr, idx) - GetFieldAccess::ListIndex { key: index } => { - match expr { - // Special case for array_agg(expr)[index] to NTH_VALUE(expr, index) - Expr::AggregateFunction(agg_func) if is_array_agg(&agg_func) => { - Ok(PlannerSimplifyResult::Simplified(Expr::AggregateFunction( - datafusion_expr::expr::AggregateFunction::new( - AggregateFunction::NthValue, - agg_func - .args - .into_iter() - .chain(std::iter::once(*index)) - .collect(), - agg_func.distinct, - agg_func.filter, - agg_func.order_by, - agg_func.null_treatment, - ), - ))) - } - _ => Ok(PlannerSimplifyResult::Simplified( - self.array_element.call(vec![expr, *index]), - )), - } - } - // expr[start, stop, stride] ==> array_slice(expr, start, stop, stride) - GetFieldAccess::ListRange { - start, - stop, - stride, - } => Ok(PlannerSimplifyResult::Simplified( - self.array_slice.call(vec![expr, *start, *stop, *stride]), - )), - } - } -} - -fn is_array_agg(agg_func: &datafusion_expr::expr::AggregateFunction) -> bool { - agg_func.func_def - == datafusion_expr::expr::AggregateFunctionDefinition::BuiltIn( - AggregateFunction::ArrayAgg, - ) -} - #[cfg(test)] mod tests { use std::collections::HashMap; diff --git a/datafusion/sql/src/expr/value.rs b/datafusion/sql/src/expr/value.rs index 225e4736c621..625dd8236f92 100644 --- a/datafusion/sql/src/expr/value.rs +++ b/datafusion/sql/src/expr/value.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use crate::planner::{ContextProvider, PlannerContext, PlannerSimplifyResult, SqlToRel}; +use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; use arrow::compute::kernels::cast_utils::parse_interval_month_day_nano; use arrow::datatypes::DECIMAL128_MAX_PRECISION; use arrow_schema::DataType; @@ -24,6 +24,7 @@ use datafusion_common::{ ScalarValue, }; use datafusion_expr::expr::{BinaryExpr, Placeholder}; +use datafusion_expr::planner::PlannerSimplifyResult; use datafusion_expr::{lit, Expr, Operator}; use log::debug; use sqlparser::ast::{BinaryOperator, Expr as SQLExpr, Interval, Value}; @@ -144,17 +145,11 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .collect::>>()?; let mut exprs = values; - let num_planners = self.planners.len(); - for (i, planner) in self.planners.iter().enumerate() { + for planner in self.planners.iter() { match planner.plan_array_literal(exprs, schema)? { PlannerSimplifyResult::Simplified(expr) => { return Ok(expr); } - PlannerSimplifyResult::OriginalArray(_) if i + 1 == num_planners => { - return internal_err!( - "Expected a simplified result, but none was found" - ) - } PlannerSimplifyResult::OriginalArray(values) => exprs = values, _ => { return exec_err!( @@ -164,7 +159,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } } - internal_err!("Unexpect to reach here") + internal_err!("Expected a simplified result, but none was found") } /// Convert a SQL interval expression to a DataFusion logical plan diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index 7fbe01b664ed..eab9f49c051b 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -21,18 +21,15 @@ use std::sync::Arc; use std::vec; use arrow_schema::*; -use datafusion_common::file_options::file_type::FileType; use datafusion_common::{ field_not_found, internal_err, plan_datafusion_err, DFSchemaRef, SchemaError, }; -use datafusion_expr::{GetFieldAccess, WindowUDF}; +use datafusion_expr::planner::UserDefinedPlanner; use sqlparser::ast::TimezoneInfo; use sqlparser::ast::{ArrayElemTypeDef, ExactNumberInfo}; use sqlparser::ast::{ColumnDef as SQLColumnDef, ColumnOption}; use sqlparser::ast::{DataType as SQLDataType, Ident, ObjectName, TableAlias}; -use crate::expr::{ArrayFunctionPlanner, FieldAccessPlanner}; -use datafusion_common::config::ConfigOptions; use datafusion_common::TableReference; use datafusion_common::{ not_impl_err, plan_err, unqualified_field_not_found, DFSchema, DataFusionError, @@ -40,64 +37,11 @@ use datafusion_common::{ }; use datafusion_expr::logical_plan::{LogicalPlan, LogicalPlanBuilder}; use datafusion_expr::utils::find_column_exprs; -use datafusion_expr::TableSource; -use datafusion_expr::{col, AggregateUDF, Expr, ScalarUDF}; +use datafusion_expr::{col, Expr}; use crate::utils::make_decimal_type; -/// The ContextProvider trait allows the query planner to obtain meta-data about tables and -/// functions referenced in SQL statements -pub trait ContextProvider { - /// Getter for a datasource - fn get_table_source(&self, name: TableReference) -> Result>; - - fn get_file_type(&self, _ext: &str) -> Result> { - not_impl_err!("Registered file types are not supported") - } - - /// Getter for a table function - fn get_table_function_source( - &self, - _name: &str, - _args: Vec, - ) -> Result> { - not_impl_err!("Table Functions are not supported") - } - - /// This provides a worktable (an intermediate table that is used to store the results of a CTE during execution) - /// We don't directly implement this in the logical plan's ['SqlToRel`] - /// because the sql code needs access to a table that contains execution-related types that can't be a direct dependency - /// of the sql crate (namely, the `CteWorktable`). - /// The [`ContextProvider`] provides a way to "hide" this dependency. - fn create_cte_work_table( - &self, - _name: &str, - _schema: SchemaRef, - ) -> Result> { - not_impl_err!("Recursive CTE is not implemented") - } - - /// Getter for a UDF description - fn get_function_meta(&self, name: &str) -> Option>; - /// Getter for a UDAF description - fn get_aggregate_meta(&self, name: &str) -> Option>; - /// Getter for a UDWF - fn get_window_meta(&self, name: &str) -> Option>; - /// Getter for system/user-defined variable type - fn get_variable_type(&self, variable_names: &[String]) -> Option; - - /// Get configuration options - fn options(&self) -> &ConfigOptions; - - /// Get all user defined scalar function names - fn udf_names(&self) -> Vec; - - /// Get all user defined aggregate function names - fn udaf_names(&self) -> Vec; - - /// Get all user defined window function names - fn udwf_names(&self) -> Vec; -} +pub use datafusion_expr::planner::ContextProvider; /// SQL parser options #[derive(Debug)] @@ -237,63 +181,55 @@ impl PlannerContext { } } -/// This trait allows users to customize the behavior of the SQL planner -pub trait UserDefinedPlanner { - /// Plan the binary operation between two expressions, return None if not possible - fn plan_binary_op( - &self, - expr: BinaryExpr, - _schema: &DFSchema, - ) -> Result { - Ok(PlannerSimplifyResult::OriginalBinaryExpr(expr)) - } - - /// Plan the field access expression, return None if not possible - fn plan_field_access( - &self, - expr: FieldAccessExpr, - _schema: &DFSchema, - ) -> Result { - Ok(PlannerSimplifyResult::OriginalFieldAccessExpr(expr)) - } - - fn plan_array_literal( - &self, - exprs: Vec, - _schema: &DFSchema, - ) -> Result { - Ok(PlannerSimplifyResult::OriginalArray(exprs)) - } -} - -pub struct BinaryExpr { - pub op: sqlparser::ast::BinaryOperator, - pub left: Expr, - pub right: Expr, -} - -pub struct FieldAccessExpr { - pub field_access: GetFieldAccess, - pub expr: Expr, -} - -pub enum PlannerSimplifyResult { - /// The function call was simplified to an entirely new Expr - Simplified(Expr), - /// the function call could not be simplified, and the arguments - /// are return unmodified. - OriginalBinaryExpr(BinaryExpr), - OriginalFieldAccessExpr(FieldAccessExpr), - OriginalArray(Vec), -} - -pub enum PlanSimplifyResult { - /// The function call was simplified to an entirely new Expr - Simplified(Expr), - /// the function call could not be simplified, and the arguments - /// are return unmodified. - Original(Vec), -} +// /// This trait allows users to customize the behavior of the SQL planner +// pub trait UserDefinedPlanner { +// /// Plan the binary operation between two expressions, return None if not possible +// fn plan_binary_op( +// &self, +// expr: BinaryExpr, +// _schema: &DFSchema, +// ) -> Result { +// Ok(PlannerSimplifyResult::OriginalBinaryExpr(expr)) +// } + +// /// Plan the field access expression, return None if not possible +// fn plan_field_access( +// &self, +// expr: FieldAccessExpr, +// _schema: &DFSchema, +// ) -> Result { +// Ok(PlannerSimplifyResult::OriginalFieldAccessExpr(expr)) +// } + +// fn plan_array_literal( +// &self, +// exprs: Vec, +// _schema: &DFSchema, +// ) -> Result { +// Ok(PlannerSimplifyResult::OriginalArray(exprs)) +// } +// } + +// pub struct BinaryExpr { +// pub op: sqlparser::ast::BinaryOperator, +// pub left: Expr, +// pub right: Expr, +// } + +// pub struct FieldAccessExpr { +// pub field_access: GetFieldAccess, +// pub expr: Expr, +// } + +// pub enum PlannerSimplifyResult { +// /// The function call was simplified to an entirely new Expr +// Simplified(Expr), +// /// the function call could not be simplified, and the arguments +// /// are return unmodified. +// OriginalBinaryExpr(BinaryExpr), +// OriginalFieldAccessExpr(FieldAccessExpr), +// OriginalArray(Vec), +// } /// SQL query planner pub struct SqlToRel<'a, S: ContextProvider> { @@ -322,10 +258,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { /// Create a new query planner pub fn new_with_options(context_provider: &'a S, options: ParserOptions) -> Self { let normalize = options.enable_ident_normalization; - let array_planner = - Arc::new(ArrayFunctionPlanner::try_new(context_provider).unwrap()) as _; - let field_access_planner = - Arc::new(FieldAccessPlanner::try_new(context_provider).unwrap()) as _; SqlToRel { context_provider, @@ -333,9 +265,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { normalizer: IdentNormalizer::new(normalize), planners: vec![], } - // todo put this somewhere else - .with_user_defined_planner(array_planner) - .with_user_defined_planner(field_access_planner) } pub fn build_schema(&self, columns: Vec) -> Result { From d38670239ff4a32b173309f79c225b521c2956c4 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Sun, 30 Jun 2024 13:46:19 +0800 Subject: [PATCH 07/17] license Signed-off-by: jayzhan211 --- datafusion/expr/src/planner.rs | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/datafusion/expr/src/planner.rs b/datafusion/expr/src/planner.rs index df0e156dee52..a1d666cd02ff 100644 --- a/datafusion/expr/src/planner.rs +++ b/datafusion/expr/src/planner.rs @@ -1,3 +1,22 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! SQL query planner module + use std::sync::Arc; use arrow::datatypes::{DataType, SchemaRef}; From 24372e4b4b1073e66a30cffb63eca020fc683553 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Sun, 30 Jun 2024 14:00:17 +0800 Subject: [PATCH 08/17] cleanup Signed-off-by: jayzhan211 --- datafusion-cli/Cargo.lock | 1 - .../core/src/execution/session_state.rs | 6 +- datafusion/functions-array/Cargo.toml | 1 - datafusion/functions-array/src/planner.rs | 138 +++++------------- datafusion/sql/src/planner.rs | 50 ------- 5 files changed, 40 insertions(+), 156 deletions(-) diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index 5292a0dbab1b..f5ea18a8a1d3 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -1319,7 +1319,6 @@ dependencies = [ "datafusion-execution", "datafusion-expr", "datafusion-functions", - "datafusion-sql", "itertools", "log", "paste", diff --git a/datafusion/core/src/execution/session_state.rs b/datafusion/core/src/execution/session_state.rs index 752762010410..01b52e7a34db 100644 --- a/datafusion/core/src/execution/session_state.rs +++ b/datafusion/core/src/execution/session_state.rs @@ -953,11 +953,9 @@ impl SessionState { // register crate of array expressions (if enabled) #[cfg(feature = "array_expressions")] { - let array_planner = - Arc::new(ArrayFunctionPlanner::try_new(provider).unwrap()) as _; + let array_planner = Arc::new(ArrayFunctionPlanner::default()) as _; - let field_access_planner = - Arc::new(FieldAccessPlanner::try_new(provider).unwrap()) as _; + let field_access_planner = Arc::new(FieldAccessPlanner::default()) as _; query .with_user_defined_planner(array_planner) diff --git a/datafusion/functions-array/Cargo.toml b/datafusion/functions-array/Cargo.toml index 784677e66462..faf5bac3caf8 100644 --- a/datafusion/functions-array/Cargo.toml +++ b/datafusion/functions-array/Cargo.toml @@ -49,7 +49,6 @@ datafusion-common = { workspace = true } datafusion-execution = { workspace = true } datafusion-expr = { workspace = true } datafusion-functions = { workspace = true } -datafusion-sql = { workspace = true } itertools = { version = "0.12", features = ["use_std"] } log = { workspace = true } paste = "1.0.14" diff --git a/datafusion/functions-array/src/planner.rs b/datafusion/functions-array/src/planner.rs index 8dbfaba87436..92985f3aee06 100644 --- a/datafusion/functions-array/src/planner.rs +++ b/datafusion/functions-array/src/planner.rs @@ -15,57 +15,22 @@ // specific language governing permissions and limitations // under the License. -use std::sync::Arc; - -use datafusion_common::{internal_err, utils::list_ndims, DFSchema, Result}; +use datafusion_common::{utils::list_ndims, DFSchema, Result}; use datafusion_expr::{ - lit, - planner::{ - BinaryExpr, ContextProvider, FieldAccessExpr, PlannerSimplifyResult, - UserDefinedPlanner, - }, - AggregateFunction, Expr, ExprSchemable, GetFieldAccess, ScalarUDF, + planner::{BinaryExpr, FieldAccessExpr, PlannerSimplifyResult, UserDefinedPlanner}, + AggregateFunction, Expr, ExprSchemable, GetFieldAccess, }; +use datafusion_functions::expr_fn::get_field; -pub struct ArrayFunctionPlanner { - array_concat: Arc, - array_append: Arc, - array_prepend: Arc, - array_has_all: Arc, - make_array: Arc, -} - -impl ArrayFunctionPlanner { - pub fn try_new(context_provider: &dyn ContextProvider) -> Result { - let Some(array_concat) = context_provider.get_function_meta("array_concat") - else { - return internal_err!("array_concat not found"); - }; - let Some(array_append) = context_provider.get_function_meta("array_append") - else { - return internal_err!("array_append not found"); - }; - let Some(array_prepend) = context_provider.get_function_meta("array_prepend") - else { - return internal_err!("array_prepend not found"); - }; - let Some(array_has_all) = context_provider.get_function_meta("array_has_all") - else { - return internal_err!("array_has_all not found"); - }; - let Some(make_array) = context_provider.get_function_meta("make_array") else { - return internal_err!("make_array not found"); - }; +use crate::{ + array_has::array_has_all, + expr_fn::{array_append, array_concat, array_prepend}, + extract::{array_element, array_slice}, + make_array::make_array, +}; - Ok(Self { - array_concat, - array_append, - array_prepend, - array_has_all, - make_array, - }) - } -} +#[derive(Default)] +pub struct ArrayFunctionPlanner {} impl UserDefinedPlanner for ArrayFunctionPlanner { fn plan_binary_op( @@ -93,17 +58,15 @@ impl UserDefinedPlanner for ArrayFunctionPlanner { // TODO: concat function ignore null, but string concat takes null into consideration // we can rewrite it to concat if we can configure the behaviour of concat function to the one like `string concat operator` } else if left_list_ndims == right_list_ndims { - return Ok(PlannerSimplifyResult::Simplified( - self.array_concat.call(vec![left, right]), - )); + return Ok(PlannerSimplifyResult::Simplified(array_concat(vec![ + left, right, + ]))); } else if left_list_ndims > right_list_ndims { - return Ok(PlannerSimplifyResult::Simplified( - self.array_append.call(vec![left, right]), - )); + return Ok(PlannerSimplifyResult::Simplified(array_append(left, right))); } else if left_list_ndims < right_list_ndims { - return Ok(PlannerSimplifyResult::Simplified( - self.array_prepend.call(vec![left, right]), - )); + return Ok(PlannerSimplifyResult::Simplified(array_prepend( + left, right, + ))); } } else if matches!( op, @@ -118,14 +81,14 @@ impl UserDefinedPlanner for ArrayFunctionPlanner { if left_list_ndims > 0 && right_list_ndims > 0 { if op == sqlparser::ast::BinaryOperator::AtArrow { // array1 @> array2 -> array_has_all(array1, array2) - return Ok(PlannerSimplifyResult::Simplified( - self.array_has_all.call(vec![left, right]), - )); + return Ok(PlannerSimplifyResult::Simplified(array_has_all( + left, right, + ))); } else { // array1 <@ array2 -> array_has_all(array2, array1) - return Ok(PlannerSimplifyResult::Simplified( - self.array_has_all.call(vec![right, left]), - )); + return Ok(PlannerSimplifyResult::Simplified(array_has_all( + right, left, + ))); } } } @@ -142,38 +105,12 @@ impl UserDefinedPlanner for ArrayFunctionPlanner { exprs: Vec, _schema: &DFSchema, ) -> Result { - Ok(PlannerSimplifyResult::Simplified( - self.make_array.call(exprs), - )) + Ok(PlannerSimplifyResult::Simplified(make_array(exprs))) } } -pub struct FieldAccessPlanner { - get_field: Arc, - array_element: Arc, - array_slice: Arc, -} - -impl FieldAccessPlanner { - pub fn try_new(context_provider: &dyn ContextProvider) -> Result { - let Some(get_field) = context_provider.get_function_meta("get_field") else { - return internal_err!("get_feild not found"); - }; - let Some(array_element) = context_provider.get_function_meta("array_element") - else { - return internal_err!("array_element not found"); - }; - let Some(array_slice) = context_provider.get_function_meta("array_slice") else { - return internal_err!("array_slice not found"); - }; - - Ok(Self { - get_field, - array_element, - array_slice, - }) - } -} +#[derive(Default)] +pub struct FieldAccessPlanner {} impl UserDefinedPlanner for FieldAccessPlanner { fn plan_field_access( @@ -186,9 +123,7 @@ impl UserDefinedPlanner for FieldAccessPlanner { match field_access { // expr["field"] => get_field(expr, "field") GetFieldAccess::NamedStructField { name } => { - Ok(PlannerSimplifyResult::Simplified( - self.get_field.call(vec![expr, lit(name)]), - )) + Ok(PlannerSimplifyResult::Simplified(get_field(expr, name))) } // expr[idx] ==> array_element(expr, idx) GetFieldAccess::ListIndex { key: index } => { @@ -210,9 +145,9 @@ impl UserDefinedPlanner for FieldAccessPlanner { ), ))) } - _ => Ok(PlannerSimplifyResult::Simplified( - self.array_element.call(vec![expr, *index]), - )), + _ => Ok(PlannerSimplifyResult::Simplified(array_element( + expr, *index, + ))), } } // expr[start, stop, stride] ==> array_slice(expr, start, stop, stride) @@ -220,9 +155,12 @@ impl UserDefinedPlanner for FieldAccessPlanner { start, stop, stride, - } => Ok(PlannerSimplifyResult::Simplified( - self.array_slice.call(vec![expr, *start, *stop, *stride]), - )), + } => Ok(PlannerSimplifyResult::Simplified(array_slice( + expr, + *start, + *stop, + Some(*stride), + ))), } } } diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index eab9f49c051b..ec9dfc2af8a2 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -181,56 +181,6 @@ impl PlannerContext { } } -// /// This trait allows users to customize the behavior of the SQL planner -// pub trait UserDefinedPlanner { -// /// Plan the binary operation between two expressions, return None if not possible -// fn plan_binary_op( -// &self, -// expr: BinaryExpr, -// _schema: &DFSchema, -// ) -> Result { -// Ok(PlannerSimplifyResult::OriginalBinaryExpr(expr)) -// } - -// /// Plan the field access expression, return None if not possible -// fn plan_field_access( -// &self, -// expr: FieldAccessExpr, -// _schema: &DFSchema, -// ) -> Result { -// Ok(PlannerSimplifyResult::OriginalFieldAccessExpr(expr)) -// } - -// fn plan_array_literal( -// &self, -// exprs: Vec, -// _schema: &DFSchema, -// ) -> Result { -// Ok(PlannerSimplifyResult::OriginalArray(exprs)) -// } -// } - -// pub struct BinaryExpr { -// pub op: sqlparser::ast::BinaryOperator, -// pub left: Expr, -// pub right: Expr, -// } - -// pub struct FieldAccessExpr { -// pub field_access: GetFieldAccess, -// pub expr: Expr, -// } - -// pub enum PlannerSimplifyResult { -// /// The function call was simplified to an entirely new Expr -// Simplified(Expr), -// /// the function call could not be simplified, and the arguments -// /// are return unmodified. -// OriginalBinaryExpr(BinaryExpr), -// OriginalFieldAccessExpr(FieldAccessExpr), -// OriginalArray(Vec), -// } - /// SQL query planner pub struct SqlToRel<'a, S: ContextProvider> { pub(crate) context_provider: &'a S, From e8890091cc74b3bbe47f46062ae6ea7ba6e6e82c Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Sun, 30 Jun 2024 15:22:33 +0800 Subject: [PATCH 09/17] fix Signed-off-by: jayzhan211 --- datafusion/core/src/execution/session_state.rs | 7 ++++--- datafusion/expr/src/planner.rs | 5 +++-- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/datafusion/core/src/execution/session_state.rs b/datafusion/core/src/execution/session_state.rs index 01b52e7a34db..ac94ee61fcb2 100644 --- a/datafusion/core/src/execution/session_state.rs +++ b/datafusion/core/src/execution/session_state.rs @@ -67,7 +67,6 @@ use datafusion_expr::{ AggregateUDF, Explain, Expr, ExprSchemable, LogicalPlan, ScalarUDF, TableSource, WindowUDF, }; -use datafusion_functions_array::planner::{ArrayFunctionPlanner, FieldAccessPlanner}; use datafusion_optimizer::simplify_expressions::ExprSimplifier; use datafusion_optimizer::{ Analyzer, AnalyzerRule, Optimizer, OptimizerConfig, OptimizerRule, @@ -953,9 +952,11 @@ impl SessionState { // register crate of array expressions (if enabled) #[cfg(feature = "array_expressions")] { - let array_planner = Arc::new(ArrayFunctionPlanner::default()) as _; + let array_planner = + Arc::new(functions_array::planner::ArrayFunctionPlanner::default()) as _; - let field_access_planner = Arc::new(FieldAccessPlanner::default()) as _; + let field_access_planner = + Arc::new(functions_array::planner::FieldAccessPlanner::default()) as _; query .with_user_defined_planner(array_planner) diff --git a/datafusion/expr/src/planner.rs b/datafusion/expr/src/planner.rs index a1d666cd02ff..64ed85d7e406 100644 --- a/datafusion/expr/src/planner.rs +++ b/datafusion/expr/src/planner.rs @@ -83,7 +83,7 @@ pub trait ContextProvider { /// This trait allows users to customize the behavior of the SQL planner pub trait UserDefinedPlanner { - /// Plan the binary operation between two expressions, return None if not possible + /// Plan the binary operation between two expressions, returns OriginalBinaryExpr if not possible fn plan_binary_op( &self, expr: BinaryExpr, @@ -92,7 +92,7 @@ pub trait UserDefinedPlanner { Ok(PlannerSimplifyResult::OriginalBinaryExpr(expr)) } - /// Plan the field access expression, return None if not possible + /// Plan the field access expression, returns OriginalFieldAccessExpr if not possible fn plan_field_access( &self, expr: FieldAccessExpr, @@ -101,6 +101,7 @@ pub trait UserDefinedPlanner { Ok(PlannerSimplifyResult::OriginalFieldAccessExpr(expr)) } + // Plan the array literal, returns OriginalArray if not possible fn plan_array_literal( &self, exprs: Vec, From 10477a8b1a63f9bee373dd28756f7c6ba652abd6 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Sun, 30 Jun 2024 16:37:28 +0800 Subject: [PATCH 10/17] change nested array test Signed-off-by: jayzhan211 --- datafusion/sqllogictest/test_files/array.slt | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index 7917f1d78da8..2ce0a29986c0 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -4910,10 +4910,11 @@ select array_ndims(arrow_cast([null], 'List(List(List(Int64)))')); 3 # array_ndims scalar function #2 +# TODO: dimensions 20 is the maximum without stack overflow, find a way to enable deep nested array query II -select array_ndims(array_repeat(array_repeat(array_repeat(1, 3), 2), 1)), array_ndims([[[[[[[[[[[[[[[[[[[[[1]]]]]]]]]]]]]]]]]]]]]); +select array_ndims(array_repeat(array_repeat(array_repeat(1, 3), 2), 1)), array_ndims([[[[[[[[[[[[[[[[[[[[1]]]]]]]]]]]]]]]]]]]]); ---- -3 21 +3 20 # array_ndims scalar function #3 query II From b662b1e9cd90500d3c1a9c19aa04ca2e6b08d2de Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Sun, 30 Jun 2024 20:44:18 +0800 Subject: [PATCH 11/17] address comment Signed-off-by: jayzhan211 --- datafusion/expr/src/lib.rs | 1 + datafusion/expr/src/planner.rs | 42 +++++++++------- datafusion/functions-array/Cargo.toml | 1 - datafusion/functions-array/src/planner.rs | 60 +++++++++-------------- datafusion/sql/src/expr/mod.rs | 27 ++++------ datafusion/sql/src/expr/value.rs | 14 ++---- datafusion/sql/src/planner.rs | 6 +-- 7 files changed, 65 insertions(+), 86 deletions(-) diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index 38f0617fc5fa..5f1d3c9d5c6b 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -82,6 +82,7 @@ pub use partition_evaluator::PartitionEvaluator; pub use signature::{ ArrayFunctionSignature, Signature, TypeSignature, Volatility, TIMEZONE_WILDCARD, }; +pub use sqlparser; pub use table_source::{TableProviderFilterPushDown, TableSource, TableType}; pub use udaf::{AggregateExt, AggregateUDF, AggregateUDFImpl, ReversedUDAF}; pub use udf::{ScalarUDF, ScalarUDFImpl}; diff --git a/datafusion/expr/src/planner.rs b/datafusion/expr/src/planner.rs index 64ed85d7e406..99aea9956b0e 100644 --- a/datafusion/expr/src/planner.rs +++ b/datafusion/expr/src/planner.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! SQL query planner module +//! [`ContextProvider`] and [`UserDefinedPlanner`] APIs to customize SQL query planning use std::sync::Arc; @@ -27,8 +27,9 @@ use datafusion_common::{ use crate::{AggregateUDF, Expr, GetFieldAccess, ScalarUDF, TableSource, WindowUDF}; -/// The ContextProvider trait allows the query planner to obtain meta-data about tables and -/// functions referenced in SQL statements +/// Provides the `SQL` query planner meta-data about tables and +/// functions referenced in SQL statements, without a direct dependency on other +/// DataFusion structures pub trait ContextProvider { /// Getter for a datasource fn get_table_source(&self, name: TableReference) -> Result>; @@ -82,23 +83,23 @@ pub trait ContextProvider { } /// This trait allows users to customize the behavior of the SQL planner -pub trait UserDefinedPlanner { +pub trait UserDefinedSQLPlanner { /// Plan the binary operation between two expressions, returns OriginalBinaryExpr if not possible fn plan_binary_op( &self, - expr: BinaryExpr, + expr: RawBinaryExpr, _schema: &DFSchema, - ) -> Result { - Ok(PlannerSimplifyResult::OriginalBinaryExpr(expr)) + ) -> Result> { + Ok(PlannerResult::Original(expr)) } /// Plan the field access expression, returns OriginalFieldAccessExpr if not possible fn plan_field_access( &self, - expr: FieldAccessExpr, + expr: RawFieldAccessExpr, _schema: &DFSchema, - ) -> Result { - Ok(PlannerSimplifyResult::OriginalFieldAccessExpr(expr)) + ) -> Result> { + Ok(PlannerResult::Original(expr)) } // Plan the array literal, returns OriginalArray if not possible @@ -106,28 +107,33 @@ pub trait UserDefinedPlanner { &self, exprs: Vec, _schema: &DFSchema, - ) -> Result { - Ok(PlannerSimplifyResult::OriginalArray(exprs)) + ) -> Result>> { + Ok(PlannerResult::Original(exprs)) } } -pub struct BinaryExpr { +/// An operator with two arguments to plan +/// +/// Note `left` and `right` are DataFusion [`Expr`]s but the `op` is the SQL AST operator. +/// This structure is used by [`UserDefinedPlanner`] to plan operators with custom expressions. +pub struct RawBinaryExpr { pub op: sqlparser::ast::BinaryOperator, pub left: Expr, pub right: Expr, } -pub struct FieldAccessExpr { +/// An expression with GetFieldAccess to plan +/// +/// This structure is used by [`UserDefinedPlanner`] to plan operators with custom expressions. +pub struct RawFieldAccessExpr { pub field_access: GetFieldAccess, pub expr: Expr, } -pub enum PlannerSimplifyResult { +pub enum PlannerResult { /// The function call was simplified to an entirely new Expr Simplified(Expr), /// the function call could not be simplified, and the arguments /// are return unmodified. - OriginalBinaryExpr(BinaryExpr), - OriginalFieldAccessExpr(FieldAccessExpr), - OriginalArray(Vec), + Original(T), } diff --git a/datafusion/functions-array/Cargo.toml b/datafusion/functions-array/Cargo.toml index faf5bac3caf8..eb1ef9e03f31 100644 --- a/datafusion/functions-array/Cargo.toml +++ b/datafusion/functions-array/Cargo.toml @@ -52,7 +52,6 @@ datafusion-functions = { workspace = true } itertools = { version = "0.12", features = ["use_std"] } log = { workspace = true } paste = "1.0.14" -sqlparser = { workspace = true } [dev-dependencies] criterion = { version = "0.5", features = ["async_tokio"] } diff --git a/datafusion/functions-array/src/planner.rs b/datafusion/functions-array/src/planner.rs index 92985f3aee06..d0faf254df42 100644 --- a/datafusion/functions-array/src/planner.rs +++ b/datafusion/functions-array/src/planner.rs @@ -15,10 +15,12 @@ // specific language governing permissions and limitations // under the License. +//! SQL planning extensions like [`ArrayFunctionPlanner`] and [`FieldAccessPlanner`] + use datafusion_common::{utils::list_ndims, DFSchema, Result}; use datafusion_expr::{ - planner::{BinaryExpr, FieldAccessExpr, PlannerSimplifyResult, UserDefinedPlanner}, - AggregateFunction, Expr, ExprSchemable, GetFieldAccess, + planner::{PlannerResult, RawBinaryExpr, RawFieldAccessExpr, UserDefinedSQLPlanner}, + sqlparser, AggregateFunction, Expr, ExprSchemable, GetFieldAccess, }; use datafusion_functions::expr_fn::get_field; @@ -32,13 +34,13 @@ use crate::{ #[derive(Default)] pub struct ArrayFunctionPlanner {} -impl UserDefinedPlanner for ArrayFunctionPlanner { +impl UserDefinedSQLPlanner for ArrayFunctionPlanner { fn plan_binary_op( &self, - expr: BinaryExpr, + expr: RawBinaryExpr, schema: &DFSchema, - ) -> Result { - let BinaryExpr { op, left, right } = expr; + ) -> Result> { + let RawBinaryExpr { op, left, right } = expr; if op == sqlparser::ast::BinaryOperator::StringConcat { let left_type = left.get_type(schema)?; @@ -58,15 +60,11 @@ impl UserDefinedPlanner for ArrayFunctionPlanner { // TODO: concat function ignore null, but string concat takes null into consideration // we can rewrite it to concat if we can configure the behaviour of concat function to the one like `string concat operator` } else if left_list_ndims == right_list_ndims { - return Ok(PlannerSimplifyResult::Simplified(array_concat(vec![ - left, right, - ]))); + return Ok(PlannerResult::Simplified(array_concat(vec![left, right]))); } else if left_list_ndims > right_list_ndims { - return Ok(PlannerSimplifyResult::Simplified(array_append(left, right))); + return Ok(PlannerResult::Simplified(array_append(left, right))); } else if left_list_ndims < right_list_ndims { - return Ok(PlannerSimplifyResult::Simplified(array_prepend( - left, right, - ))); + return Ok(PlannerResult::Simplified(array_prepend(left, right))); } } else if matches!( op, @@ -81,56 +79,48 @@ impl UserDefinedPlanner for ArrayFunctionPlanner { if left_list_ndims > 0 && right_list_ndims > 0 { if op == sqlparser::ast::BinaryOperator::AtArrow { // array1 @> array2 -> array_has_all(array1, array2) - return Ok(PlannerSimplifyResult::Simplified(array_has_all( - left, right, - ))); + return Ok(PlannerResult::Simplified(array_has_all(left, right))); } else { // array1 <@ array2 -> array_has_all(array2, array1) - return Ok(PlannerSimplifyResult::Simplified(array_has_all( - right, left, - ))); + return Ok(PlannerResult::Simplified(array_has_all(right, left))); } } } - Ok(PlannerSimplifyResult::OriginalBinaryExpr(BinaryExpr { - op, - left, - right, - })) + Ok(PlannerResult::Original(RawBinaryExpr { op, left, right })) } fn plan_array_literal( &self, exprs: Vec, _schema: &DFSchema, - ) -> Result { - Ok(PlannerSimplifyResult::Simplified(make_array(exprs))) + ) -> Result>> { + Ok(PlannerResult::Simplified(make_array(exprs))) } } #[derive(Default)] pub struct FieldAccessPlanner {} -impl UserDefinedPlanner for FieldAccessPlanner { +impl UserDefinedSQLPlanner for FieldAccessPlanner { fn plan_field_access( &self, - expr: FieldAccessExpr, + expr: RawFieldAccessExpr, _schema: &DFSchema, - ) -> Result { - let FieldAccessExpr { expr, field_access } = expr; + ) -> Result> { + let RawFieldAccessExpr { expr, field_access } = expr; match field_access { // expr["field"] => get_field(expr, "field") GetFieldAccess::NamedStructField { name } => { - Ok(PlannerSimplifyResult::Simplified(get_field(expr, name))) + Ok(PlannerResult::Simplified(get_field(expr, name))) } // expr[idx] ==> array_element(expr, idx) GetFieldAccess::ListIndex { key: index } => { match expr { // Special case for array_agg(expr)[index] to NTH_VALUE(expr, index) Expr::AggregateFunction(agg_func) if is_array_agg(&agg_func) => { - Ok(PlannerSimplifyResult::Simplified(Expr::AggregateFunction( + Ok(PlannerResult::Simplified(Expr::AggregateFunction( datafusion_expr::expr::AggregateFunction::new( AggregateFunction::NthValue, agg_func @@ -145,9 +135,7 @@ impl UserDefinedPlanner for FieldAccessPlanner { ), ))) } - _ => Ok(PlannerSimplifyResult::Simplified(array_element( - expr, *index, - ))), + _ => Ok(PlannerResult::Simplified(array_element(expr, *index))), } } // expr[start, stop, stride] ==> array_slice(expr, start, stop, stride) @@ -155,7 +143,7 @@ impl UserDefinedPlanner for FieldAccessPlanner { start, stop, stride, - } => Ok(PlannerSimplifyResult::Simplified(array_slice( + } => Ok(PlannerResult::Simplified(array_slice( expr, *start, *stop, diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index 0f68ee44f6ab..7d11e7eb2787 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -17,9 +17,8 @@ use arrow_schema::DataType; use arrow_schema::TimeUnit; -use datafusion_common::exec_err; -use datafusion_expr::planner::FieldAccessExpr; -use datafusion_expr::planner::PlannerSimplifyResult; +use datafusion_expr::planner::PlannerResult; +use datafusion_expr::planner::RawFieldAccessExpr; use sqlparser::ast::{CastKind, Expr as SQLExpr, Subscript, TrimWhereField, Value}; use datafusion_common::{ @@ -107,22 +106,19 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { schema: &DFSchema, ) -> Result { // try extension planers - let mut binary_expr = datafusion_expr::planner::BinaryExpr { op, left, right }; + let mut binary_expr = datafusion_expr::planner::RawBinaryExpr { op, left, right }; for planner in self.planners.iter() { match planner.plan_binary_op(binary_expr, schema)? { - PlannerSimplifyResult::Simplified(expr) => { + PlannerResult::Simplified(expr) => { return Ok(expr); } - PlannerSimplifyResult::OriginalBinaryExpr(expr) => { + PlannerResult::Original(expr) => { binary_expr = expr; } - _ => { - return exec_err!("Unexpected result encountered. Did you expect an OriginalBinaryExpr?") - } } } - let datafusion_expr::planner::BinaryExpr { op, left, right } = binary_expr; + let datafusion_expr::planner::RawBinaryExpr { op, left, right } = binary_expr; Ok(Expr::BinaryExpr(BinaryExpr::new( Box::new(left), self.parse_sql_binary_op(op)?, @@ -282,18 +278,13 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } }; - let mut field_access_expr = FieldAccessExpr { expr, field_access }; + let mut field_access_expr = RawFieldAccessExpr { expr, field_access }; for planner in self.planners.iter() { match planner.plan_field_access(field_access_expr, schema)? { - PlannerSimplifyResult::Simplified(expr) => { - return Ok(expr) - } - PlannerSimplifyResult::OriginalFieldAccessExpr(expr) => { + PlannerResult::Simplified(expr) => return Ok(expr), + PlannerResult::Original(expr) => { field_access_expr = expr; } - _ => { - return exec_err!("Unexpected result encountered. Did you expect an OriginalFieldAccessExpr?") - } } } diff --git a/datafusion/sql/src/expr/value.rs b/datafusion/sql/src/expr/value.rs index 625dd8236f92..a8f4dad7f1fa 100644 --- a/datafusion/sql/src/expr/value.rs +++ b/datafusion/sql/src/expr/value.rs @@ -20,11 +20,10 @@ use arrow::compute::kernels::cast_utils::parse_interval_month_day_nano; use arrow::datatypes::DECIMAL128_MAX_PRECISION; use arrow_schema::DataType; use datafusion_common::{ - exec_err, internal_err, not_impl_err, plan_err, DFSchema, DataFusionError, Result, - ScalarValue, + internal_err, not_impl_err, plan_err, DFSchema, DataFusionError, Result, ScalarValue, }; use datafusion_expr::expr::{BinaryExpr, Placeholder}; -use datafusion_expr::planner::PlannerSimplifyResult; +use datafusion_expr::planner::PlannerResult; use datafusion_expr::{lit, Expr, Operator}; use log::debug; use sqlparser::ast::{BinaryOperator, Expr as SQLExpr, Interval, Value}; @@ -147,15 +146,10 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let mut exprs = values; for planner in self.planners.iter() { match planner.plan_array_literal(exprs, schema)? { - PlannerSimplifyResult::Simplified(expr) => { + PlannerResult::Simplified(expr) => { return Ok(expr); } - PlannerSimplifyResult::OriginalArray(values) => exprs = values, - _ => { - return exec_err!( - "Unexpected result encountered. Did you expect an OriginalArray?" - ) - } + PlannerResult::Original(values) => exprs = values, } } diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index ec9dfc2af8a2..443cd64a940c 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -24,7 +24,7 @@ use arrow_schema::*; use datafusion_common::{ field_not_found, internal_err, plan_datafusion_err, DFSchemaRef, SchemaError, }; -use datafusion_expr::planner::UserDefinedPlanner; +use datafusion_expr::planner::UserDefinedSQLPlanner; use sqlparser::ast::TimezoneInfo; use sqlparser::ast::{ArrayElemTypeDef, ExactNumberInfo}; use sqlparser::ast::{ColumnDef as SQLColumnDef, ColumnOption}; @@ -187,7 +187,7 @@ pub struct SqlToRel<'a, S: ContextProvider> { pub(crate) options: ParserOptions, pub(crate) normalizer: IdentNormalizer, /// user defined planner extensions - pub(crate) planners: Vec>, + pub(crate) planners: Vec>, } impl<'a, S: ContextProvider> SqlToRel<'a, S> { @@ -199,7 +199,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { /// add an user defined planner pub fn with_user_defined_planner( mut self, - planner: Arc, + planner: Arc, ) -> Self { self.planners.push(planner); self From 160e032f1916ab191a9b69761f273c7880f02873 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Sun, 30 Jun 2024 21:05:12 +0800 Subject: [PATCH 12/17] fix stack overflow issue Signed-off-by: jayzhan211 --- datafusion/sql/src/expr/value.rs | 10 ++++++++++ datafusion/sqllogictest/test_files/array.slt | 5 ++--- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/datafusion/sql/src/expr/value.rs b/datafusion/sql/src/expr/value.rs index a8f4dad7f1fa..c4a765014bd7 100644 --- a/datafusion/sql/src/expr/value.rs +++ b/datafusion/sql/src/expr/value.rs @@ -131,6 +131,8 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { ))) } + // IMPORTANT: Keep sql_array_literal's function body small to prevent stack overflow + // This function is recursively called, potentially leading to deep call stacks. pub(super) fn sql_array_literal( &self, elements: Vec, @@ -143,6 +145,14 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { }) .collect::>>()?; + self.try_plan_array_literal(values, schema) + } + + fn try_plan_array_literal( + &self, + values: Vec, + schema: &DFSchema, + ) -> Result { let mut exprs = values; for planner in self.planners.iter() { match planner.plan_array_literal(exprs, schema)? { diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index 2ce0a29986c0..7917f1d78da8 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -4910,11 +4910,10 @@ select array_ndims(arrow_cast([null], 'List(List(List(Int64)))')); 3 # array_ndims scalar function #2 -# TODO: dimensions 20 is the maximum without stack overflow, find a way to enable deep nested array query II -select array_ndims(array_repeat(array_repeat(array_repeat(1, 3), 2), 1)), array_ndims([[[[[[[[[[[[[[[[[[[[1]]]]]]]]]]]]]]]]]]]]); +select array_ndims(array_repeat(array_repeat(array_repeat(1, 3), 2), 1)), array_ndims([[[[[[[[[[[[[[[[[[[[[1]]]]]]]]]]]]]]]]]]]]]); ---- -3 20 +3 21 # array_ndims scalar function #3 query II From c4e2b33adbb34311d1cae0dff813fa31de4848b2 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Mon, 1 Jul 2024 08:35:11 +0800 Subject: [PATCH 13/17] upd cli Signed-off-by: jayzhan211 --- datafusion-cli/Cargo.lock | 101 +++++++++++++++++++------------------- 1 file changed, 50 insertions(+), 51 deletions(-) diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index f5ea18a8a1d3..5fc8dbcfdfb3 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -387,7 +387,7 @@ checksum = "c6fa2087f2753a7da8cc1c0dbfcf89579dd57458e36769de5ac750b4671737ca" dependencies = [ "proc-macro2", "quote", - "syn 2.0.67", + "syn 2.0.68", ] [[package]] @@ -757,9 +757,9 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" [[package]] name = "bitflags" -version = "2.5.0" +version = "2.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cf4b9d6a944f767f8e5e0db018570623c85f3d925ac718db4e06d0187adb21c1" +checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de" [[package]] name = "blake2" @@ -875,9 +875,9 @@ dependencies = [ [[package]] name = "cc" -version = "1.0.99" +version = "1.0.103" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "96c51067fd44124faa7f870b4b1c969379ad32b2ba805aa959430ceaa384f695" +checksum = "2755ff20a1d93490d26ba33a6f092a38a508398a5320df5d4b3014fcccce9410" dependencies = [ "jobserver", "libc", @@ -981,7 +981,7 @@ version = "7.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b34115915337defe99b2aff5c2ce6771e5fbc4079f4b506301f5cf394c8452f7" dependencies = [ - "strum 0.26.2", + "strum 0.26.3", "strum_macros 0.26.4", "unicode-width", ] @@ -1099,7 +1099,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "edb49164822f3ee45b17acd4a208cfc1251410cf0cad9a833234c9890774dd9f" dependencies = [ "quote", - "syn 2.0.67", + "syn 2.0.68", ] [[package]] @@ -1262,7 +1262,7 @@ dependencies = [ "paste", "serde_json", "sqlparser", - "strum 0.26.2", + "strum 0.26.3", "strum_macros 0.26.4", ] @@ -1322,7 +1322,6 @@ dependencies = [ "itertools", "log", "paste", - "sqlparser", ] [[package]] @@ -1427,7 +1426,7 @@ dependencies = [ "log", "regex", "sqlparser", - "strum 0.26.2", + "strum 0.26.3", ] [[package]] @@ -1505,9 +1504,9 @@ checksum = "fea41bba32d969b513997752735605054bc0dfa92b4c56bf1189f2e174be7a10" [[package]] name = "either" -version = "1.12.0" +version = "1.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3dca9240753cf90908d7e4aac30f630662b02aebaa1b58a3cadabdb23385b58b" +checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0" [[package]] name = "endian-type" @@ -1686,7 +1685,7 @@ checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" dependencies = [ "proc-macro2", "quote", - "syn 2.0.67", + "syn 2.0.68", ] [[package]] @@ -2254,9 +2253,9 @@ checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058" [[package]] name = "libmimalloc-sys" -version = "0.1.38" +version = "0.1.39" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0e7bb23d733dfcc8af652a78b7bf232f0e967710d044732185e561e47c0336b6" +checksum = "23aa6811d3bd4deb8a84dde645f943476d13b248d818edcf8ce0b2f37f036b44" dependencies = [ "cc", "libc", @@ -2268,7 +2267,7 @@ version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c0ff37bd590ca25063e35af745c343cb7a0271906fb7b37e4813e8f79f00268d" dependencies = [ - "bitflags 2.5.0", + "bitflags 2.6.0", "libc", ] @@ -2290,9 +2289,9 @@ dependencies = [ [[package]] name = "log" -version = "0.4.21" +version = "0.4.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "90ed8c1e510134f979dbc4f070f87d4313098b704861a105fe34231c70a3901c" +checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24" [[package]] name = "lz4_flex" @@ -2332,9 +2331,9 @@ checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" [[package]] name = "mimalloc" -version = "0.1.42" +version = "0.1.43" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e9186d86b79b52f4a77af65604b51225e8db1d6ee7e3f41aec1e40829c71a176" +checksum = "68914350ae34959d83f732418d51e2427a794055d0b9529f48259ac07af65633" dependencies = [ "libmimalloc-sys", ] @@ -2407,9 +2406,9 @@ dependencies = [ [[package]] name = "num-bigint" -version = "0.4.5" +version = "0.4.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c165a9ab64cf766f73521c0dd2cfdff64f488b8f0b3e621face3462d3db536d7" +checksum = "a5e44f723f1133c9deac646763579fdb3ac745e418f2a7af9cd0c431da1f20b9" dependencies = [ "num-integer", "num-traits", @@ -2483,9 +2482,9 @@ dependencies = [ [[package]] name = "object" -version = "0.36.0" +version = "0.36.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "576dfe1fc8f9df304abb159d767a29d0476f7750fbf8aa7ad07816004a207434" +checksum = "081b846d1d56ddfc18fdf1a922e4f6e07a11768ea1b92dec44e42b72712ccfce" dependencies = [ "memchr", ] @@ -2699,7 +2698,7 @@ checksum = "2f38a4412a78282e09a2cf38d195ea5420d15ba0602cb375210efbc877243965" dependencies = [ "proc-macro2", "quote", - "syn 2.0.67", + "syn 2.0.68", ] [[package]] @@ -2913,7 +2912,7 @@ version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c82cf8cff14456045f55ec4241383baeff27af886adb72ffb2162f99911de0fd" dependencies = [ - "bitflags 2.5.0", + "bitflags 2.6.0", ] [[package]] @@ -3096,7 +3095,7 @@ version = "0.38.34" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "70dc5ec042f7a43c4a73241207cecc9873a06d45debb38b329f8541d85c2730f" dependencies = [ - "bitflags 2.5.0", + "bitflags 2.6.0", "errno", "libc", "linux-raw-sys", @@ -3265,7 +3264,7 @@ version = "2.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c627723fd09706bacdb5cf41499e95098555af3c3c29d014dc3c458ef6be11c0" dependencies = [ - "bitflags 2.5.0", + "bitflags 2.6.0", "core-foundation", "core-foundation-sys", "libc", @@ -3311,14 +3310,14 @@ checksum = "500cbc0ebeb6f46627f50f3f5811ccf6bf00643be300b4c3eabc0ef55dc5b5ba" dependencies = [ "proc-macro2", "quote", - "syn 2.0.67", + "syn 2.0.68", ] [[package]] name = "serde_json" -version = "1.0.117" +version = "1.0.119" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "455182ea6142b14f93f4bc5320a2b31c1f266b66a4a5c858b013302a5d8cbfc3" +checksum = "e8eddb61f0697cc3989c5d64b452f5488e2b8a60fd7d5076a3045076ffef8cb0" dependencies = [ "itoa", "ryu", @@ -3446,7 +3445,7 @@ checksum = "01b2e185515564f15375f593fb966b5718bc624ba77fe49fa4616ad619690554" dependencies = [ "proc-macro2", "quote", - "syn 2.0.67", + "syn 2.0.68", ] [[package]] @@ -3475,9 +3474,9 @@ checksum = "290d54ea6f91c969195bdbcd7442c8c2a2ba87da8bf60a7ee86a235d4bc1e125" [[package]] name = "strum" -version = "0.26.2" +version = "0.26.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5d8cec3501a5194c432b2b7976db6b7d10ec95c253208b45f83f7136aa985e29" +checksum = "8fec0f0aef304996cf250b31b5a10dee7980c85da9d759361292b8bca5a18f06" dependencies = [ "strum_macros 0.26.4", ] @@ -3492,7 +3491,7 @@ dependencies = [ "proc-macro2", "quote", "rustversion", - "syn 2.0.67", + "syn 2.0.68", ] [[package]] @@ -3505,14 +3504,14 @@ dependencies = [ "proc-macro2", "quote", "rustversion", - "syn 2.0.67", + "syn 2.0.68", ] [[package]] name = "subtle" -version = "2.6.0" +version = "2.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0d0208408ba0c3df17ed26eb06992cb1a1268d41b2c0e12e65203fbe3972cee5" +checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "syn" @@ -3527,9 +3526,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.67" +version = "2.0.68" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ff8655ed1d86f3af4ee3fd3263786bc14245ad17c4c7e85ba7187fb3ae028c90" +checksum = "901fa70d88b9d6c98022e23b4136f9f3e54e4662c3bc1bd1d84a42a9a0f0c1e9" dependencies = [ "proc-macro2", "quote", @@ -3592,7 +3591,7 @@ checksum = "46c3384250002a6d5af4d114f2845d37b57521033f30d5c3f46c4d70e1197533" dependencies = [ "proc-macro2", "quote", - "syn 2.0.67", + "syn 2.0.68", ] [[package]] @@ -3647,9 +3646,9 @@ dependencies = [ [[package]] name = "tinyvec" -version = "1.6.0" +version = "1.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87cc5ceb3875bb20c2890005a4e226a4651264a5c75edb2421b52861a0a0cb50" +checksum = "c55115c6fbe2d2bef26eb09ad74bde02d8255476fc0c7b515ef09fbb35742d82" dependencies = [ "tinyvec_macros", ] @@ -3687,7 +3686,7 @@ checksum = "5f5ae998a069d4b5aba8ee9dad856af7d520c3699e6159b185c2acd48155d39a" dependencies = [ "proc-macro2", "quote", - "syn 2.0.67", + "syn 2.0.68", ] [[package]] @@ -3784,7 +3783,7 @@ checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.67", + "syn 2.0.68", ] [[package]] @@ -3829,7 +3828,7 @@ checksum = "f03ca4cb38206e2bef0700092660bb74d696f808514dae47fa1467cbfe26e96e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.67", + "syn 2.0.68", ] [[package]] @@ -3908,9 +3907,9 @@ checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" [[package]] name = "uuid" -version = "1.8.0" +version = "1.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a183cf7feeba97b4dd1c0d46788634f6221d87fa961b305bed08c851829efcc0" +checksum = "5de17fd2f7da591098415cff336e12965a28061ddace43b59cb3c430179c9439" dependencies = [ "getrandom", "serde", @@ -3983,7 +3982,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.67", + "syn 2.0.68", "wasm-bindgen-shared", ] @@ -4017,7 +4016,7 @@ checksum = "e94f17b526d0a461a191c78ea52bbce64071ed5c04c9ffe424dcb38f74171bb7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.67", + "syn 2.0.68", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -4282,7 +4281,7 @@ checksum = "15e934569e47891f7d9411f1a451d947a60e000ab3bd24fbb970f000387d1b3b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.67", + "syn 2.0.68", ] [[package]] From dafd53ed8eaaeee6fa8175e9300cc68f7e55e27d Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Mon, 1 Jul 2024 09:13:36 +0800 Subject: [PATCH 14/17] fix doc Signed-off-by: jayzhan211 --- datafusion/expr/src/planner.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/expr/src/planner.rs b/datafusion/expr/src/planner.rs index 99aea9956b0e..43b42ce50979 100644 --- a/datafusion/expr/src/planner.rs +++ b/datafusion/expr/src/planner.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! [`ContextProvider`] and [`UserDefinedPlanner`] APIs to customize SQL query planning +//! [`ContextProvider`] and [`UserDefinedSQLPlanner`] APIs to customize SQL query planning use std::sync::Arc; From 5e9af66eada708bdfc738bf774172f0e36358073 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Mon, 1 Jul 2024 11:51:56 +0800 Subject: [PATCH 15/17] fix doc Signed-off-by: jayzhan211 --- datafusion/expr/src/planner.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/datafusion/expr/src/planner.rs b/datafusion/expr/src/planner.rs index 43b42ce50979..9f427b73da0d 100644 --- a/datafusion/expr/src/planner.rs +++ b/datafusion/expr/src/planner.rs @@ -115,7 +115,7 @@ pub trait UserDefinedSQLPlanner { /// An operator with two arguments to plan /// /// Note `left` and `right` are DataFusion [`Expr`]s but the `op` is the SQL AST operator. -/// This structure is used by [`UserDefinedPlanner`] to plan operators with custom expressions. +/// This structure is used by [`UserDefinedSQLPlanner`] to plan operators with custom expressions. pub struct RawBinaryExpr { pub op: sqlparser::ast::BinaryOperator, pub left: Expr, @@ -124,7 +124,7 @@ pub struct RawBinaryExpr { /// An expression with GetFieldAccess to plan /// -/// This structure is used by [`UserDefinedPlanner`] to plan operators with custom expressions. +/// This structure is used by [`UserDefinedSQLPlanner`] to plan operators with custom expressions. pub struct RawFieldAccessExpr { pub field_access: GetFieldAccess, pub expr: Expr, From 596362681f6541d61f44e0ba0f4844a1fc25f03e Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Mon, 1 Jul 2024 07:09:54 -0400 Subject: [PATCH 16/17] Rename PlannerResult::Simplified to PlannerResult::Planned --- datafusion/expr/src/planner.rs | 14 +++++++++----- datafusion/functions-array/src/planner.rs | 20 ++++++++++---------- datafusion/sql/src/expr/mod.rs | 4 ++-- datafusion/sql/src/expr/value.rs | 2 +- 4 files changed, 22 insertions(+), 18 deletions(-) diff --git a/datafusion/expr/src/planner.rs b/datafusion/expr/src/planner.rs index 9f427b73da0d..910b74763835 100644 --- a/datafusion/expr/src/planner.rs +++ b/datafusion/expr/src/planner.rs @@ -114,8 +114,11 @@ pub trait UserDefinedSQLPlanner { /// An operator with two arguments to plan /// -/// Note `left` and `right` are DataFusion [`Expr`]s but the `op` is the SQL AST operator. -/// This structure is used by [`UserDefinedSQLPlanner`] to plan operators with custom expressions. +/// Note `left` and `right` are DataFusion [`Expr`]s but the `op` is the SQL AST +/// operator. +/// +/// This structure is used by [`UserDefinedSQLPlanner`] to plan operators with +/// custom expressions. pub struct RawBinaryExpr { pub op: sqlparser::ast::BinaryOperator, pub left: Expr, @@ -124,15 +127,16 @@ pub struct RawBinaryExpr { /// An expression with GetFieldAccess to plan /// -/// This structure is used by [`UserDefinedSQLPlanner`] to plan operators with custom expressions. +/// This structure is used by [`UserDefinedSQLPlanner`] to plan operators with +/// custom expressions. pub struct RawFieldAccessExpr { pub field_access: GetFieldAccess, pub expr: Expr, } pub enum PlannerResult { - /// The function call was simplified to an entirely new Expr - Simplified(Expr), + /// The function call was successfully planned as a new Expr + Planned(Expr), /// the function call could not be simplified, and the arguments /// are return unmodified. Original(T), diff --git a/datafusion/functions-array/src/planner.rs b/datafusion/functions-array/src/planner.rs index d0faf254df42..f33ee56582cf 100644 --- a/datafusion/functions-array/src/planner.rs +++ b/datafusion/functions-array/src/planner.rs @@ -60,11 +60,11 @@ impl UserDefinedSQLPlanner for ArrayFunctionPlanner { // TODO: concat function ignore null, but string concat takes null into consideration // we can rewrite it to concat if we can configure the behaviour of concat function to the one like `string concat operator` } else if left_list_ndims == right_list_ndims { - return Ok(PlannerResult::Simplified(array_concat(vec![left, right]))); + return Ok(PlannerResult::Planned(array_concat(vec![left, right]))); } else if left_list_ndims > right_list_ndims { - return Ok(PlannerResult::Simplified(array_append(left, right))); + return Ok(PlannerResult::Planned(array_append(left, right))); } else if left_list_ndims < right_list_ndims { - return Ok(PlannerResult::Simplified(array_prepend(left, right))); + return Ok(PlannerResult::Planned(array_prepend(left, right))); } } else if matches!( op, @@ -79,10 +79,10 @@ impl UserDefinedSQLPlanner for ArrayFunctionPlanner { if left_list_ndims > 0 && right_list_ndims > 0 { if op == sqlparser::ast::BinaryOperator::AtArrow { // array1 @> array2 -> array_has_all(array1, array2) - return Ok(PlannerResult::Simplified(array_has_all(left, right))); + return Ok(PlannerResult::Planned(array_has_all(left, right))); } else { // array1 <@ array2 -> array_has_all(array2, array1) - return Ok(PlannerResult::Simplified(array_has_all(right, left))); + return Ok(PlannerResult::Planned(array_has_all(right, left))); } } } @@ -95,7 +95,7 @@ impl UserDefinedSQLPlanner for ArrayFunctionPlanner { exprs: Vec, _schema: &DFSchema, ) -> Result>> { - Ok(PlannerResult::Simplified(make_array(exprs))) + Ok(PlannerResult::Planned(make_array(exprs))) } } @@ -113,14 +113,14 @@ impl UserDefinedSQLPlanner for FieldAccessPlanner { match field_access { // expr["field"] => get_field(expr, "field") GetFieldAccess::NamedStructField { name } => { - Ok(PlannerResult::Simplified(get_field(expr, name))) + Ok(PlannerResult::Planned(get_field(expr, name))) } // expr[idx] ==> array_element(expr, idx) GetFieldAccess::ListIndex { key: index } => { match expr { // Special case for array_agg(expr)[index] to NTH_VALUE(expr, index) Expr::AggregateFunction(agg_func) if is_array_agg(&agg_func) => { - Ok(PlannerResult::Simplified(Expr::AggregateFunction( + Ok(PlannerResult::Planned(Expr::AggregateFunction( datafusion_expr::expr::AggregateFunction::new( AggregateFunction::NthValue, agg_func @@ -135,7 +135,7 @@ impl UserDefinedSQLPlanner for FieldAccessPlanner { ), ))) } - _ => Ok(PlannerResult::Simplified(array_element(expr, *index))), + _ => Ok(PlannerResult::Planned(array_element(expr, *index))), } } // expr[start, stop, stride] ==> array_slice(expr, start, stop, stride) @@ -143,7 +143,7 @@ impl UserDefinedSQLPlanner for FieldAccessPlanner { start, stop, stride, - } => Ok(PlannerResult::Simplified(array_slice( + } => Ok(PlannerResult::Planned(array_slice( expr, *start, *stop, diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index 7d11e7eb2787..08594f92f483 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -109,7 +109,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let mut binary_expr = datafusion_expr::planner::RawBinaryExpr { op, left, right }; for planner in self.planners.iter() { match planner.plan_binary_op(binary_expr, schema)? { - PlannerResult::Simplified(expr) => { + PlannerResult::Planned(expr) => { return Ok(expr); } PlannerResult::Original(expr) => { @@ -281,7 +281,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let mut field_access_expr = RawFieldAccessExpr { expr, field_access }; for planner in self.planners.iter() { match planner.plan_field_access(field_access_expr, schema)? { - PlannerResult::Simplified(expr) => return Ok(expr), + PlannerResult::Planned(expr) => return Ok(expr), PlannerResult::Original(expr) => { field_access_expr = expr; } diff --git a/datafusion/sql/src/expr/value.rs b/datafusion/sql/src/expr/value.rs index c4a765014bd7..5cd6ffc68788 100644 --- a/datafusion/sql/src/expr/value.rs +++ b/datafusion/sql/src/expr/value.rs @@ -156,7 +156,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let mut exprs = values; for planner in self.planners.iter() { match planner.plan_array_literal(exprs, schema)? { - PlannerResult::Simplified(expr) => { + PlannerResult::Planned(expr) => { return Ok(expr); } PlannerResult::Original(values) => exprs = values, From f1f1b4c7be3fc72268230d7323cb8213b6e6957a Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Mon, 1 Jul 2024 07:15:54 -0400 Subject: [PATCH 17/17] Update comments and add Debug/Clone impls --- datafusion/expr/src/planner.rs | 9 ++++++--- datafusion/sql/src/expr/mod.rs | 2 +- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/datafusion/expr/src/planner.rs b/datafusion/expr/src/planner.rs index 910b74763835..1febfbec7ef0 100644 --- a/datafusion/expr/src/planner.rs +++ b/datafusion/expr/src/planner.rs @@ -119,6 +119,7 @@ pub trait UserDefinedSQLPlanner { /// /// This structure is used by [`UserDefinedSQLPlanner`] to plan operators with /// custom expressions. +#[derive(Debug, Clone)] pub struct RawBinaryExpr { pub op: sqlparser::ast::BinaryOperator, pub left: Expr, @@ -129,15 +130,17 @@ pub struct RawBinaryExpr { /// /// This structure is used by [`UserDefinedSQLPlanner`] to plan operators with /// custom expressions. +#[derive(Debug, Clone)] pub struct RawFieldAccessExpr { pub field_access: GetFieldAccess, pub expr: Expr, } +/// Result of planning a raw expr with [`UserDefinedSQLPlanner`] +#[derive(Debug, Clone)] pub enum PlannerResult { - /// The function call was successfully planned as a new Expr + /// The raw expression was successfully planned as a new [`Expr`] Planned(Expr), - /// the function call could not be simplified, and the arguments - /// are return unmodified. + /// The raw expression could not be planned, and is returned unmodified Original(T), } diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index 08594f92f483..786ea288fa0e 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -288,7 +288,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } } - internal_err!("Expected a simplified result, but none was found") + not_impl_err!("GetFieldAccess not supported by UserDefinedExtensionPlanners: {field_access_expr:?}") } SQLExpr::CompoundIdentifier(ids) => {