Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Get expr planners when creating new planner #11485

Merged
merged 8 commits into from
Jul 17, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions datafusion-examples/examples/sql_frontend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,10 @@ struct MyContextProvider {
}

impl ContextProvider for MyContextProvider {
fn get_expr_planners(&self) -> Vec<Arc<dyn datafusion_expr::planner::ExprPlanner>> {
vec![]
}

fn get_table_source(&self, name: TableReference) -> Result<Arc<dyn TableSource>> {
if name.table() == "person" {
Ok(Arc::new(MyTableSource {
Expand Down
26 changes: 10 additions & 16 deletions datafusion/core/src/execution/session_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -516,7 +516,7 @@ impl SessionState {
}
}

let query = self.build_sql_query_planner(&provider);
let query = SqlToRel::new_with_options(&provider, self.get_parser_options());
query.statement_to_plan(statement)
}

Expand Down Expand Up @@ -569,7 +569,7 @@ impl SessionState {
tables: HashMap::new(),
};

let query = self.build_sql_query_planner(&provider);
let query = SqlToRel::new_with_options(&provider, self.get_parser_options());
query.sql_to_expr(sql_expr, df_schema, &mut PlannerContext::new())
}

Expand Down Expand Up @@ -854,20 +854,6 @@ impl SessionState {
let udtf = self.table_functions.remove(name);
Ok(udtf.map(|x| x.function().clone()))
}

fn build_sql_query_planner<'a, S>(&self, provider: &'a S) -> SqlToRel<'a, S>
where
S: ContextProvider,
{
let mut query = SqlToRel::new_with_options(provider, self.get_parser_options());

// custom planners are registered first, so they're run first and take precedence over built-in planners
for planner in self.expr_planners.iter() {
query = query.with_user_defined_planner(planner.clone());
}

query
}
}

/// A builder to be used for building [`SessionState`]'s. Defaults will
Expand Down Expand Up @@ -1597,12 +1583,20 @@ impl SessionStateDefaults {
}
}

/// Adapter that implements the [`ContextProvider`] trait for a [`SessionState`]
///
/// This is used so the SQL planner can access the state of the session without
/// having a direct dependency on the [`SessionState`] struct (and core crate)
struct SessionContextProvider<'a> {
state: &'a SessionState,
tables: HashMap<String, Arc<dyn TableSource>>,
}

impl<'a> ContextProvider for SessionContextProvider<'a> {
fn get_expr_planners(&self) -> Vec<Arc<dyn ExprPlanner>> {
self.state.expr_planners()
}

fn get_table_source(
&self,
name: TableReference,
Expand Down
4 changes: 4 additions & 0 deletions datafusion/core/tests/optimizer_integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,10 @@ impl MyContextProvider {
}

impl ContextProvider for MyContextProvider {
fn get_expr_planners(&self) -> Vec<Arc<dyn datafusion_expr::planner::ExprPlanner>> {
vec![]
}

fn get_table_source(&self, name: TableReference) -> Result<Arc<dyn TableSource>> {
let table_name = name.table();
if table_name.starts_with("test") {
Expand Down
3 changes: 3 additions & 0 deletions datafusion/expr/src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@ pub trait ContextProvider {
not_impl_err!("Recursive CTE is not implemented")
}

/// Getter for expr planners
fn get_expr_planners(&self) -> Vec<Arc<dyn ExprPlanner>>;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I spent some time trying to figure out if we could avoid this Vec cloning all the time -- with

Suggested change
fn get_expr_planners(&self) -> Vec<Arc<dyn ExprPlanner>>;
fn get_expr_planners(&self) -> &[Arc<dyn ExprPlanner>] {&[]}

I think this is possible -- I made jayzhan211#3 to show how it looks


/// Getter for a UDF description
fn get_function_meta(&self, name: &str) -> Option<Arc<ScalarUDF>>;
/// Getter for a UDAF description
Expand Down
4 changes: 4 additions & 0 deletions datafusion/optimizer/tests/optimizer_integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,10 @@ impl MyContextProvider {
}

impl ContextProvider for MyContextProvider {
fn get_expr_planners(&self) -> Vec<Arc<dyn datafusion_expr::planner::ExprPlanner>> {
vec![]
}

fn get_table_source(&self, name: TableReference) -> Result<Arc<dyn TableSource>> {
let table_name = name.table();
if table_name.starts_with("test") {
Expand Down
4 changes: 4 additions & 0 deletions datafusion/sql/examples/sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,10 @@ fn create_table_source(fields: Vec<Field>) -> Arc<dyn TableSource> {
}

impl ContextProvider for MyContextProvider {
fn get_expr_planners(&self) -> Vec<Arc<dyn datafusion_expr::planner::ExprPlanner>> {
vec![]
}

fn get_table_source(&self, name: TableReference) -> Result<Arc<dyn TableSource>> {
match self.tables.get(name.table()) {
Some(table) => Ok(Arc::clone(table)),
Expand Down
20 changes: 13 additions & 7 deletions datafusion/sql/src/expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
) -> Result<Expr> {
// try extension planers
let mut binary_expr = datafusion_expr::planner::RawBinaryExpr { op, left, right };
for planner in self.planners.iter() {
for planner in self.context_provider.get_expr_planners().iter() {
match planner.plan_binary_op(binary_expr, schema)? {
PlannerResult::Planned(expr) => {
return Ok(expr);
Expand Down Expand Up @@ -184,7 +184,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
self.sql_expr_to_logical_expr(*expr, schema, planner_context)?,
];

for planner in self.planners.iter() {
for planner in self.context_provider.get_expr_planners().iter() {
match planner.plan_extract(extract_args)? {
PlannerResult::Planned(expr) => return Ok(expr),
PlannerResult::Original(args) => {
Expand Down Expand Up @@ -283,7 +283,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
};

let mut field_access_expr = RawFieldAccessExpr { expr, field_access };
for planner in self.planners.iter() {
for planner in self.context_provider.get_expr_planners().iter() {
match planner.plan_field_access(field_access_expr, schema)? {
PlannerResult::Planned(expr) => return Ok(expr),
PlannerResult::Original(expr) => {
Expand Down Expand Up @@ -653,7 +653,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
self.create_struct_expr(values, schema, planner_context)?
};

for planner in self.planners.iter() {
for planner in self.context_provider.get_expr_planners().iter() {
match planner.plan_struct_literal(create_struct_args, is_named_struct)? {
PlannerResult::Planned(expr) => return Ok(expr),
PlannerResult::Original(args) => create_struct_args = args,
Expand All @@ -673,7 +673,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
self.sql_expr_to_logical_expr(substr_expr, schema, planner_context)?;
let fullstr = self.sql_expr_to_logical_expr(str_expr, schema, planner_context)?;
let mut position_args = vec![fullstr, substr];
for planner in self.planners.iter() {
for planner in self.context_provider.get_expr_planners().iter() {
match planner.plan_position(position_args)? {
PlannerResult::Planned(expr) => return Ok(expr),
PlannerResult::Original(args) => {
Expand Down Expand Up @@ -703,7 +703,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {

let mut raw_expr = RawDictionaryExpr { keys, values };

for planner in self.planners.iter() {
for planner in self.context_provider.get_expr_planners().iter() {
match planner.plan_dictionary_literal(raw_expr, schema)? {
PlannerResult::Planned(expr) => {
return Ok(expr);
Expand Down Expand Up @@ -927,7 +927,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
}
None => vec![arg, what_arg, from_arg],
};
for planner in self.planners.iter() {
for planner in self.context_provider.get_expr_planners().iter() {
match planner.plan_overlay(overlay_args)? {
PlannerResult::Planned(expr) => return Ok(expr),
PlannerResult::Original(args) => overlay_args = args,
Expand Down Expand Up @@ -979,6 +979,12 @@ mod tests {
}

impl ContextProvider for TestContextProvider {
fn get_expr_planners(
&self,
) -> Vec<Arc<dyn datafusion_expr::planner::ExprPlanner>> {
vec![]
}

fn get_table_source(&self, name: TableReference) -> Result<Arc<dyn TableSource>> {
match self.tables.get(name.table()) {
Some(table) => Ok(Arc::clone(table)),
Expand Down
2 changes: 1 addition & 1 deletion datafusion/sql/src/expr/substring.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
}
};

for planner in self.planners.iter() {
for planner in self.context_provider.get_expr_planners().iter() {
match planner.plan_substring(substring_args)? {
PlannerResult::Planned(expr) => return Ok(expr),
PlannerResult::Original(args) => {
Expand Down
2 changes: 1 addition & 1 deletion datafusion/sql/src/expr/value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
schema: &DFSchema,
) -> Result<Expr> {
let mut exprs = values;
for planner in self.planners.iter() {
for planner in self.context_provider.get_expr_planners().iter() {
match planner.plan_array_literal(exprs, schema)? {
PlannerResult::Planned(expr) => {
return Ok(expr);
Expand Down
10 changes: 0 additions & 10 deletions datafusion/sql/src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ use arrow_schema::*;
use datafusion_common::{
field_not_found, internal_err, plan_datafusion_err, DFSchemaRef, SchemaError,
};
use datafusion_expr::planner::ExprPlanner;
use sqlparser::ast::TimezoneInfo;
use sqlparser::ast::{ArrayElemTypeDef, ExactNumberInfo};
use sqlparser::ast::{ColumnDef as SQLColumnDef, ColumnOption};
Expand Down Expand Up @@ -186,8 +185,6 @@ pub struct SqlToRel<'a, S: ContextProvider> {
pub(crate) context_provider: &'a S,
pub(crate) options: ParserOptions,
pub(crate) normalizer: IdentNormalizer,
/// user defined planner extensions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

pub(crate) planners: Vec<Arc<dyn ExprPlanner>>,
}

impl<'a, S: ContextProvider> SqlToRel<'a, S> {
Expand All @@ -196,12 +193,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
Self::new_with_options(context_provider, ParserOptions::default())
}

/// add an user defined planner
pub fn with_user_defined_planner(mut self, planner: Arc<dyn ExprPlanner>) -> Self {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can register planner in session state

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it makes sense to use the underlying session state as the source of truth

self.planners.push(planner);
self
}

/// Create a new query planner
pub fn new_with_options(context_provider: &'a S, options: ParserOptions) -> Self {
let normalize = options.enable_ident_normalization;
Expand All @@ -210,7 +201,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
context_provider,
options,
normalizer: IdentNormalizer::new(normalize),
planners: vec![],
}
}

Expand Down
4 changes: 4 additions & 0 deletions datafusion/sql/tests/common/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,10 @@ impl MockContextProvider {
}

impl ContextProvider for MockContextProvider {
fn get_expr_planners(&self) -> Vec<Arc<dyn datafusion_expr::planner::ExprPlanner>> {
vec![]
}

fn get_table_source(&self, name: TableReference) -> Result<Arc<dyn TableSource>> {
let schema = match name.table() {
"test" => Ok(Schema::new(vec![
Expand Down