diff --git a/datafusion/core/src/execution/session_state.rs b/datafusion/core/src/execution/session_state.rs index 75eef4345487..03ce8d3b5892 100644 --- a/datafusion/core/src/execution/session_state.rs +++ b/datafusion/core/src/execution/session_state.rs @@ -516,7 +516,7 @@ impl SessionState { } } - let query = self.build_sql_query_planner(&provider); + let query = SqlToRel::new_with_options(&provider, self.get_parser_options()); query.statement_to_plan(statement) } @@ -569,7 +569,7 @@ impl SessionState { tables: HashMap::new(), }; - let query = self.build_sql_query_planner(&provider); + let query = SqlToRel::new_with_options(&provider, self.get_parser_options()); query.sql_to_expr(sql_expr, df_schema, &mut PlannerContext::new()) } @@ -854,20 +854,6 @@ impl SessionState { let udtf = self.table_functions.remove(name); Ok(udtf.map(|x| x.function().clone())) } - - fn build_sql_query_planner<'a, S>(&self, provider: &'a S) -> SqlToRel<'a, S> - where - S: ContextProvider, - { - let mut query = SqlToRel::new_with_options(provider, self.get_parser_options()); - - // custom planners are registered first, so they're run first and take precedence over built-in planners - for planner in self.expr_planners.iter() { - query = query.with_user_defined_planner(planner.clone()); - } - - query - } } /// A builder to be used for building [`SessionState`]'s. Defaults will @@ -1597,12 +1583,20 @@ impl SessionStateDefaults { } } +/// Adapter that implements the [`ContextProvider`] trait for a [`SessionState`] +/// +/// This is used so the SQL planner can access the state of the session without +/// having a direct dependency on the [`SessionState`] struct (and core crate) struct SessionContextProvider<'a> { state: &'a SessionState, tables: HashMap>, } impl<'a> ContextProvider for SessionContextProvider<'a> { + fn get_expr_planners(&self) -> &[Arc] { + &self.state.expr_planners + } + fn get_table_source( &self, name: TableReference, @@ -1898,3 +1892,47 @@ impl<'a> SimplifyInfo for SessionSimplifyProvider<'a> { expr.get_type(self.df_schema) } } + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + + use arrow_schema::{DataType, Field, Schema}; + use datafusion_common::DFSchema; + use datafusion_common::Result; + use datafusion_expr::Expr; + use datafusion_sql::planner::{PlannerContext, SqlToRel}; + + use crate::execution::context::SessionState; + + use super::{SessionContextProvider, SessionStateBuilder}; + + #[test] + fn test_session_state_with_default_features() { + // test array planners with and without builtin planners + fn sql_to_expr(state: &SessionState) -> Result { + let provider = SessionContextProvider { + state, + tables: HashMap::new(), + }; + + let sql = "[1,2,3]"; + let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]); + let df_schema = DFSchema::try_from(schema)?; + let dialect = state.config.options().sql_parser.dialect.as_str(); + let sql_expr = state.sql_to_expr(sql, dialect)?; + + let query = SqlToRel::new_with_options(&provider, state.get_parser_options()); + query.sql_to_expr(sql_expr, &df_schema, &mut PlannerContext::new()) + } + + let state = SessionStateBuilder::new().with_default_features().build(); + + assert!(sql_to_expr(&state).is_ok()); + + // if no builtin planners exist, you should register your own, otherwise returns error + let state = SessionStateBuilder::new().build(); + + assert!(sql_to_expr(&state).is_err()) + } +} diff --git a/datafusion/expr/src/planner.rs b/datafusion/expr/src/planner.rs index 2f13923b1f10..009f3512c588 100644 --- a/datafusion/expr/src/planner.rs +++ b/datafusion/expr/src/planner.rs @@ -60,6 +60,11 @@ pub trait ContextProvider { not_impl_err!("Recursive CTE is not implemented") } + /// Getter for expr planners + fn get_expr_planners(&self) -> &[Arc] { + &[] + } + /// Getter for a UDF description fn get_function_meta(&self, name: &str) -> Option>; /// Getter for a UDAF description diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index 062ef805fd9f..71ff7c03bea2 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -111,7 +111,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { ) -> Result { // try extension planers let mut binary_expr = datafusion_expr::planner::RawBinaryExpr { op, left, right }; - for planner in self.planners.iter() { + for planner in self.context_provider.get_expr_planners() { match planner.plan_binary_op(binary_expr, schema)? { PlannerResult::Planned(expr) => { return Ok(expr); @@ -184,7 +184,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { self.sql_expr_to_logical_expr(*expr, schema, planner_context)?, ]; - for planner in self.planners.iter() { + for planner in self.context_provider.get_expr_planners() { match planner.plan_extract(extract_args)? { PlannerResult::Planned(expr) => return Ok(expr), PlannerResult::Original(args) => { @@ -283,7 +283,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { }; let mut field_access_expr = RawFieldAccessExpr { expr, field_access }; - for planner in self.planners.iter() { + for planner in self.context_provider.get_expr_planners() { match planner.plan_field_access(field_access_expr, schema)? { PlannerResult::Planned(expr) => return Ok(expr), PlannerResult::Original(expr) => { @@ -653,7 +653,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { self.create_struct_expr(values, schema, planner_context)? }; - for planner in self.planners.iter() { + for planner in self.context_provider.get_expr_planners() { match planner.plan_struct_literal(create_struct_args, is_named_struct)? { PlannerResult::Planned(expr) => return Ok(expr), PlannerResult::Original(args) => create_struct_args = args, @@ -673,7 +673,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { self.sql_expr_to_logical_expr(substr_expr, schema, planner_context)?; let fullstr = self.sql_expr_to_logical_expr(str_expr, schema, planner_context)?; let mut position_args = vec![fullstr, substr]; - for planner in self.planners.iter() { + for planner in self.context_provider.get_expr_planners() { match planner.plan_position(position_args)? { PlannerResult::Planned(expr) => return Ok(expr), PlannerResult::Original(args) => { @@ -703,7 +703,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let mut raw_expr = RawDictionaryExpr { keys, values }; - for planner in self.planners.iter() { + for planner in self.context_provider.get_expr_planners() { match planner.plan_dictionary_literal(raw_expr, schema)? { PlannerResult::Planned(expr) => { return Ok(expr); @@ -927,7 +927,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } None => vec![arg, what_arg, from_arg], }; - for planner in self.planners.iter() { + for planner in self.context_provider.get_expr_planners() { match planner.plan_overlay(overlay_args)? { PlannerResult::Planned(expr) => return Ok(expr), PlannerResult::Original(args) => overlay_args = args, diff --git a/datafusion/sql/src/expr/substring.rs b/datafusion/sql/src/expr/substring.rs index a0dfee1b9d90..f58ab5ff3612 100644 --- a/datafusion/sql/src/expr/substring.rs +++ b/datafusion/sql/src/expr/substring.rs @@ -68,7 +68,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } }; - for planner in self.planners.iter() { + for planner in self.context_provider.get_expr_planners() { match planner.plan_substring(substring_args)? { PlannerResult::Planned(expr) => return Ok(expr), PlannerResult::Original(args) => { diff --git a/datafusion/sql/src/expr/value.rs b/datafusion/sql/src/expr/value.rs index 5cd6ffc68788..1564f06fe4b9 100644 --- a/datafusion/sql/src/expr/value.rs +++ b/datafusion/sql/src/expr/value.rs @@ -154,7 +154,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { schema: &DFSchema, ) -> Result { let mut exprs = values; - for planner in self.planners.iter() { + for planner in self.context_provider.get_expr_planners() { match planner.plan_array_literal(exprs, schema)? { PlannerResult::Planned(expr) => { return Ok(expr); diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index be04f51f4f2c..901a2ad38d8c 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -24,7 +24,6 @@ use arrow_schema::*; use datafusion_common::{ field_not_found, internal_err, plan_datafusion_err, DFSchemaRef, SchemaError, }; -use datafusion_expr::planner::ExprPlanner; use sqlparser::ast::TimezoneInfo; use sqlparser::ast::{ArrayElemTypeDef, ExactNumberInfo}; use sqlparser::ast::{ColumnDef as SQLColumnDef, ColumnOption}; @@ -186,8 +185,6 @@ pub struct SqlToRel<'a, S: ContextProvider> { pub(crate) context_provider: &'a S, pub(crate) options: ParserOptions, pub(crate) normalizer: IdentNormalizer, - /// user defined planner extensions - pub(crate) planners: Vec>, } impl<'a, S: ContextProvider> SqlToRel<'a, S> { @@ -196,12 +193,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { Self::new_with_options(context_provider, ParserOptions::default()) } - /// add an user defined planner - pub fn with_user_defined_planner(mut self, planner: Arc) -> Self { - self.planners.push(planner); - self - } - /// Create a new query planner pub fn new_with_options(context_provider: &'a S, options: ParserOptions) -> Self { let normalize = options.enable_ident_normalization; @@ -210,7 +201,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { context_provider, options, normalizer: IdentNormalizer::new(normalize), - planners: vec![], } }