Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Get expr planners when creating new planner #11485

Merged
merged 8 commits into from
Jul 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 54 additions & 16 deletions datafusion/core/src/execution/session_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -516,7 +516,7 @@ impl SessionState {
}
}

let query = self.build_sql_query_planner(&provider);
let query = SqlToRel::new_with_options(&provider, self.get_parser_options());
query.statement_to_plan(statement)
}

Expand Down Expand Up @@ -569,7 +569,7 @@ impl SessionState {
tables: HashMap::new(),
};

let query = self.build_sql_query_planner(&provider);
let query = SqlToRel::new_with_options(&provider, self.get_parser_options());
query.sql_to_expr(sql_expr, df_schema, &mut PlannerContext::new())
}

Expand Down Expand Up @@ -854,20 +854,6 @@ impl SessionState {
let udtf = self.table_functions.remove(name);
Ok(udtf.map(|x| x.function().clone()))
}

/// Builds a [`SqlToRel`] over `provider` configured with this session's
/// parser options and with every planner in `self.expr_planners` registered.
///
/// Planners are registered in iteration order, so custom planners that were
/// registered first run first and take precedence over built-in planners.
fn build_sql_query_planner<'a, S>(&self, provider: &'a S) -> SqlToRel<'a, S>
where
    S: ContextProvider,
{
    let base = SqlToRel::new_with_options(provider, self.get_parser_options());

    // Thread the planner through each registration instead of rebinding a
    // mutable local on every iteration.
    self.expr_planners
        .iter()
        .fold(base, |sql_to_rel, expr_planner| {
            sql_to_rel.with_user_defined_planner(expr_planner.clone())
        })
}
}

/// A builder to be used for building [`SessionState`]'s. Defaults will
Expand Down Expand Up @@ -1597,12 +1583,20 @@ impl SessionStateDefaults {
}
}

/// Adapter that implements the [`ContextProvider`] trait for a [`SessionState`]
///
/// This is used so the SQL planner can access the state of the session without
/// having a direct dependency on the [`SessionState`] struct (and core crate)
struct SessionContextProvider<'a> {
/// Borrowed session state; the [`ContextProvider`] impl reads
/// `expr_planners` directly from it rather than keeping its own copy.
state: &'a SessionState,
/// Table sources keyed by a `String`.
// NOTE(review): presumably the key is the resolved table name — confirm
// against the code that populates this map.
tables: HashMap<String, Arc<dyn TableSource>>,
}

impl<'a> ContextProvider for SessionContextProvider<'a> {
fn get_expr_planners(&self) -> &[Arc<dyn ExprPlanner>] {
&self.state.expr_planners
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I removed the `expr_planners` field from `SessionContextProvider` and now get the planners from the state directly.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice! I tried to figure out how to do this previously and got confused about a lock or something. This is very nice 👌

}

fn get_table_source(
&self,
name: TableReference,
Expand Down Expand Up @@ -1898,3 +1892,47 @@ impl<'a> SimplifyInfo for SessionSimplifyProvider<'a> {
expr.get_type(self.df_schema)
}
}

#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use arrow_schema::{DataType, Field, Schema};
    use datafusion_common::{DFSchema, Result};
    use datafusion_expr::Expr;
    use datafusion_sql::planner::{PlannerContext, SqlToRel};

    use super::{SessionContextProvider, SessionStateBuilder};
    use crate::execution::context::SessionState;

    /// Plans the SQL array literal `[1,2,3]` into a logical [`Expr`], relying
    /// solely on the expr planners reachable through a
    /// [`SessionContextProvider`] wrapping `state`.
    fn plan_array_literal(state: &SessionState) -> Result<Expr> {
        let context_provider = SessionContextProvider {
            state,
            tables: HashMap::new(),
        };

        // Single nullable Int32 column named "a"; the literal does not refer
        // to it, but expression planning requires a schema.
        let arrow_schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);
        let df_schema = DFSchema::try_from(arrow_schema)?;

        // Parse with whichever dialect the session is configured for.
        let dialect = state.config.options().sql_parser.dialect.as_str();
        let parsed = state.sql_to_expr("[1,2,3]", dialect)?;

        let sql_to_rel =
            SqlToRel::new_with_options(&context_provider, state.get_parser_options());
        let mut planner_context = PlannerContext::new();
        sql_to_rel.sql_to_expr(parsed, &df_schema, &mut planner_context)
    }

    /// Array literals must plan with the default features installed and fail
    /// when no planners are registered at all.
    #[test]
    fn test_session_state_with_default_features() {
        // Default features bring the built-in planners, so planning succeeds.
        let with_defaults = SessionStateBuilder::new().with_default_features().build();
        assert!(plan_array_literal(&with_defaults).is_ok());

        // With no built-in planners the caller is expected to register their
        // own; otherwise planning the literal returns an error.
        let without_planners = SessionStateBuilder::new().build();
        assert!(plan_array_literal(&without_planners).is_err());
    }
}
5 changes: 5 additions & 0 deletions datafusion/expr/src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,11 @@ pub trait ContextProvider {
not_impl_err!("Recursive CTE is not implemented")
}

/// Getter for expr planners
///
/// Returns the [`ExprPlanner`]s this provider makes available. The default
/// implementation returns an empty slice, so implementors that do not
/// override it expose no planners.
fn get_expr_planners(&self) -> &[Arc<dyn ExprPlanner>] {
&[]
}

/// Getter for a UDF description
fn get_function_meta(&self, name: &str) -> Option<Arc<ScalarUDF>>;
/// Getter for a UDAF description
Expand Down
14 changes: 7 additions & 7 deletions datafusion/sql/src/expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
) -> Result<Expr> {
// try extension planners
let mut binary_expr = datafusion_expr::planner::RawBinaryExpr { op, left, right };
for planner in self.planners.iter() {
for planner in self.context_provider.get_expr_planners() {
match planner.plan_binary_op(binary_expr, schema)? {
PlannerResult::Planned(expr) => {
return Ok(expr);
Expand Down Expand Up @@ -184,7 +184,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
self.sql_expr_to_logical_expr(*expr, schema, planner_context)?,
];

for planner in self.planners.iter() {
for planner in self.context_provider.get_expr_planners() {
match planner.plan_extract(extract_args)? {
PlannerResult::Planned(expr) => return Ok(expr),
PlannerResult::Original(args) => {
Expand Down Expand Up @@ -283,7 +283,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
};

let mut field_access_expr = RawFieldAccessExpr { expr, field_access };
for planner in self.planners.iter() {
for planner in self.context_provider.get_expr_planners() {
match planner.plan_field_access(field_access_expr, schema)? {
PlannerResult::Planned(expr) => return Ok(expr),
PlannerResult::Original(expr) => {
Expand Down Expand Up @@ -653,7 +653,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
self.create_struct_expr(values, schema, planner_context)?
};

for planner in self.planners.iter() {
for planner in self.context_provider.get_expr_planners() {
match planner.plan_struct_literal(create_struct_args, is_named_struct)? {
PlannerResult::Planned(expr) => return Ok(expr),
PlannerResult::Original(args) => create_struct_args = args,
Expand All @@ -673,7 +673,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
self.sql_expr_to_logical_expr(substr_expr, schema, planner_context)?;
let fullstr = self.sql_expr_to_logical_expr(str_expr, schema, planner_context)?;
let mut position_args = vec![fullstr, substr];
for planner in self.planners.iter() {
for planner in self.context_provider.get_expr_planners() {
match planner.plan_position(position_args)? {
PlannerResult::Planned(expr) => return Ok(expr),
PlannerResult::Original(args) => {
Expand Down Expand Up @@ -703,7 +703,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {

let mut raw_expr = RawDictionaryExpr { keys, values };

for planner in self.planners.iter() {
for planner in self.context_provider.get_expr_planners() {
match planner.plan_dictionary_literal(raw_expr, schema)? {
PlannerResult::Planned(expr) => {
return Ok(expr);
Expand Down Expand Up @@ -927,7 +927,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
}
None => vec![arg, what_arg, from_arg],
};
for planner in self.planners.iter() {
for planner in self.context_provider.get_expr_planners() {
match planner.plan_overlay(overlay_args)? {
PlannerResult::Planned(expr) => return Ok(expr),
PlannerResult::Original(args) => overlay_args = args,
Expand Down
2 changes: 1 addition & 1 deletion datafusion/sql/src/expr/substring.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
}
};

for planner in self.planners.iter() {
for planner in self.context_provider.get_expr_planners() {
match planner.plan_substring(substring_args)? {
PlannerResult::Planned(expr) => return Ok(expr),
PlannerResult::Original(args) => {
Expand Down
2 changes: 1 addition & 1 deletion datafusion/sql/src/expr/value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
schema: &DFSchema,
) -> Result<Expr> {
let mut exprs = values;
for planner in self.planners.iter() {
for planner in self.context_provider.get_expr_planners() {
match planner.plan_array_literal(exprs, schema)? {
PlannerResult::Planned(expr) => {
return Ok(expr);
Expand Down
10 changes: 0 additions & 10 deletions datafusion/sql/src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ use arrow_schema::*;
use datafusion_common::{
field_not_found, internal_err, plan_datafusion_err, DFSchemaRef, SchemaError,
};
use datafusion_expr::planner::ExprPlanner;
use sqlparser::ast::TimezoneInfo;
use sqlparser::ast::{ArrayElemTypeDef, ExactNumberInfo};
use sqlparser::ast::{ColumnDef as SQLColumnDef, ColumnOption};
Expand Down Expand Up @@ -186,8 +185,6 @@ pub struct SqlToRel<'a, S: ContextProvider> {
pub(crate) context_provider: &'a S,
pub(crate) options: ParserOptions,
pub(crate) normalizer: IdentNormalizer,
/// user defined planner extensions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

pub(crate) planners: Vec<Arc<dyn ExprPlanner>>,
}

impl<'a, S: ContextProvider> SqlToRel<'a, S> {
Expand All @@ -196,12 +193,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
Self::new_with_options(context_provider, ParserOptions::default())
}

/// Add a user-defined planner
pub fn with_user_defined_planner(mut self, planner: Arc<dyn ExprPlanner>) -> Self {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can register planner in session state

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it makes sense to use the underlying session state as the source of truth

self.planners.push(planner);
self
}

/// Create a new query planner
pub fn new_with_options(context_provider: &'a S, options: ParserOptions) -> Self {
let normalize = options.enable_ident_normalization;
Expand All @@ -210,7 +201,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
context_provider,
options,
normalizer: IdentNormalizer::new(normalize),
planners: vec![],
}
}

Expand Down
Loading