diff --git a/datafusion-examples/examples/rewrite_expr.rs b/datafusion-examples/examples/rewrite_expr.rs index 5e95562033e6..9cd7d7e03326 100644 --- a/datafusion-examples/examples/rewrite_expr.rs +++ b/datafusion-examples/examples/rewrite_expr.rs @@ -16,13 +16,14 @@ // under the License. use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +use datafusion::execution::FunctionRegistry; use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{Transformed, TreeNode}; -use datafusion_common::{plan_err, DataFusionError, Result, ScalarValue}; +use datafusion_common::{internal_err, plan_err, DataFusionError, Result, ScalarValue}; use datafusion_expr::{ AggregateUDF, Between, Expr, Filter, LogicalPlan, ScalarUDF, TableSource, WindowUDF, }; -use datafusion_optimizer::analyzer::{Analyzer, AnalyzerRule}; +use datafusion_optimizer::analyzer::{Analyzer, AnalyzerConfig, AnalyzerRule}; use datafusion_optimizer::optimizer::Optimizer; use datafusion_optimizer::{utils, OptimizerConfig, OptimizerContext, OptimizerRule}; use datafusion_sql::planner::{ContextProvider, SqlToRel}; @@ -32,6 +33,38 @@ use datafusion_sql::TableReference; use std::any::Any; use std::sync::Arc; +struct ExamplesAnalyzerConfig<'a> { + config_options: &'a ConfigOptions, +} + +impl<'a> FunctionRegistry for ExamplesAnalyzerConfig<'a> { + fn udfs(&self) -> std::collections::HashSet { + std::collections::HashSet::new() + } + + fn udf(&self, _name: &str) -> Result> { + internal_err!("Mock Function Registry") + } + + fn udaf(&self, _name: &str) -> Result> { + internal_err!("Mock Function Registry") + } + + fn udwf(&self, _name: &str) -> Result> { + internal_err!("Mock Function Registry") + } +} + +impl<'a> AnalyzerConfig for ExamplesAnalyzerConfig<'a> { + fn function_registry(&self) -> &dyn FunctionRegistry { + self + } + + fn options(&self) -> &ConfigOptions { + self.config_options + } +} + pub fn main() -> Result<()> { // produce a logical plan using the datafusion-sql crate let dialect = PostgreSqlDialect {}; @@ -50,8 +83,11 @@ pub fn main() -> Result<()> { // run the analyzer with our custom rule let config = OptimizerContext::default().with_skip_failing_rules(false); let analyzer = Analyzer::with_rules(vec![Arc::new(MyAnalyzerRule {})]); + let analyzer_config = ExamplesAnalyzerConfig { + config_options: config.options(), + }; let analyzed_plan = - analyzer.execute_and_check(&logical_plan, config.options(), |_, _| {})?; + analyzer.execute_and_check(&logical_plan, &analyzer_config, |_, _| {})?; println!( "Analyzed Logical Plan:\n\n{}\n", analyzed_plan.display_indent() @@ -80,7 +116,11 @@ fn observe(plan: &LogicalPlan, rule: &dyn OptimizerRule) { struct MyAnalyzerRule {} impl AnalyzerRule for MyAnalyzerRule { - fn analyze(&self, plan: LogicalPlan, _config: &ConfigOptions) -> Result { + fn analyze( + &self, + plan: LogicalPlan, + _config: &dyn AnalyzerConfig, + ) -> Result { Self::analyze_plan(plan) } diff --git a/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs b/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs index 65414f5619a5..00360aef0027 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs @@ -349,12 +349,15 @@ mod tests { use arrow::datatypes::DataType::Decimal128; use arrow::datatypes::Schema; use arrow::datatypes::{DataType, Field}; + use datafusion_common::internal_err; use datafusion_common::{config::ConfigOptions, TableReference, ToDFSchema}; use datafusion_common::{DataFusionError, Result}; + use datafusion_execution::FunctionRegistry; use datafusion_expr::{ builder::LogicalTableSource, cast, col, lit, AggregateUDF, Expr, ScalarUDF, TableSource, WindowUDF, }; + use datafusion_optimizer::analyzer::AnalyzerConfig; use datafusion_physical_expr::execution_props::ExecutionProps; use datafusion_physical_expr::{create_physical_expr, PhysicalExpr}; use datafusion_sql::planner::ContextProvider; @@ -370,6 +373,38 @@ mod tests { use std::ops::Rem; use std::sync::Arc; + struct TestAnalyzerConfig<'a> { + config_options: &'a ConfigOptions, + } + + impl<'a> FunctionRegistry for TestAnalyzerConfig<'a> { + fn udfs(&self) -> HashSet { + HashSet::new() + } + + fn udf(&self, _name: &str) -> Result> { + internal_err!("mock function registry") + } + + fn udaf(&self, _name: &str) -> Result> { + internal_err!("mock function registry") + } + + fn udwf(&self, _name: &str) -> Result> { + internal_err!("mock function registry") + } + } + + impl<'a> AnalyzerConfig for TestAnalyzerConfig<'a> { + fn function_registry(&self) -> &dyn datafusion_execution::FunctionRegistry { + self + } + + fn options(&self) -> &ConfigOptions { + self.config_options + } + } + struct PrimitiveTypeField { name: &'static str, physical_ty: PhysicalType, @@ -1314,7 +1349,10 @@ mod tests { let analyzer = Analyzer::new(); let optimizer = Optimizer::new(); // analyze and optimize the logical plan - let plan = analyzer.execute_and_check(&plan, config.options(), |_, _| {})?; + let analyzer_config = TestAnalyzerConfig { + config_options: config.options(), + }; + let plan = analyzer.execute_and_check(&plan, &analyzer_config, |_, _| {})?; let plan = optimizer.optimize(&plan, &config, |_, _| {})?; // convert the logical plan into a physical plan let exprs = plan.expressions(); diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index 58a4f08341d6..626061fb81c5 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -104,7 +104,7 @@ use crate::catalog::information_schema::{InformationSchemaProvider, INFORMATION_ use crate::catalog::listing_schema::ListingSchemaProvider; use crate::datasource::object_store::ObjectStoreUrl; use datafusion_optimizer::{ - analyzer::{Analyzer, AnalyzerRule}, + analyzer::{Analyzer, AnalyzerConfig, AnalyzerRule}, OptimizerConfig, }; use datafusion_sql::planner::object_name_to_table_reference; @@ -1729,7 +1729,7 @@ impl SessionState { // analyze & capture output of each rule let analyzed_plan = match self.analyzer.execute_and_check( e.plan.as_ref(), - self.options(), + self, |analyzed_plan, analyzer| { let analyzer_name = analyzer.name().to_string(); let plan_type = PlanType::AnalyzedLogicalPlan { analyzer_name }; @@ -1785,9 +1785,7 @@ impl SessionState { logical_optimization_succeeded, })) } else { - let analyzed_plan = - self.analyzer - .execute_and_check(plan, self.options(), |_, _| {})?; + let analyzed_plan = self.analyzer.execute_and_check(plan, self, |_, _| {})?; self.optimizer.optimize(&analyzed_plan, self, |_, _| {}) } } @@ -1875,6 +1873,16 @@ impl SessionState { } } +impl AnalyzerConfig for SessionState { + fn function_registry(&self) -> &dyn FunctionRegistry { + self + } + + fn options(&self) -> &ConfigOptions { + self.config_options() + } +} + struct SessionContextProvider<'a> { state: &'a SessionState, tables: HashMap>, diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index cedf1d845137..0f53215e9230 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -27,7 +27,8 @@ use crate::{ aggregate_function, built_in_function, conditional_expressions::CaseBuilder, logical_plan::Subquery, AccumulatorFactoryFunction, AggregateUDF, BuiltinScalarFunction, Expr, LogicalPlan, Operator, ReturnTypeFunction, - ScalarFunctionImplementation, ScalarUDF, Signature, StateTypeFunction, Volatility, + ScalarFunctionDefinition, ScalarFunctionImplementation, ScalarUDF, Signature, + StateTypeFunction, Volatility, }; use arrow::datatypes::DataType; use datafusion_common::{Column, Result}; @@ -1007,7 +1008,7 @@ pub fn create_udwf( ) } -/// Calls a named built in function +/// Calls a named function /// ``` /// use datafusion_expr::{col, lit, call_fn}; /// @@ -1015,10 +1016,10 @@ pub fn create_udwf( /// let expr = call_fn("sin", vec![col("x")]).unwrap().lt(lit(0.2)); /// ``` pub fn call_fn(name: impl AsRef, args: Vec) -> Result { - match name.as_ref().parse::() { - Ok(fun) => Ok(Expr::ScalarFunction(ScalarFunction::new(fun, args))), - Err(e) => Err(e), - } + Ok(Expr::ScalarFunction(ScalarFunction { + func_def: ScalarFunctionDefinition::Name(Arc::from(name.as_ref())), + args, + })) } #[cfg(test)] diff --git a/datafusion/optimizer/Cargo.toml b/datafusion/optimizer/Cargo.toml index fac880867fef..676cc2a11c5b 100644 --- a/datafusion/optimizer/Cargo.toml +++ b/datafusion/optimizer/Cargo.toml @@ -43,6 +43,7 @@ arrow = { workspace = true } async-trait = { workspace = true } chrono = { workspace = true } datafusion-common = { workspace = true } +datafusion-execution = { workspace = true } datafusion-expr = { workspace = true } datafusion-physical-expr = { path = "../physical-expr", version = "33.0.0", default-features = false } hashbrown = { version = "0.14", features = ["raw"] } diff --git a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs index fd84bb80160b..2719824578fc 100644 --- a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs +++ b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs @@ -16,7 +16,7 @@ // under the License. use crate::analyzer::AnalyzerRule; -use datafusion_common::config::ConfigOptions; + use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter}; use datafusion_common::Result; use datafusion_expr::expr::{AggregateFunction, AggregateFunctionDefinition, InSubquery}; @@ -29,6 +29,8 @@ use datafusion_expr::{ }; use std::sync::Arc; +use super::AnalyzerConfig; + /// Rewrite `Count(Expr:Wildcard)` to `Count(Expr:Literal)`. /// /// Resolves issue: @@ -42,7 +44,7 @@ impl CountWildcardRule { } impl AnalyzerRule for CountWildcardRule { - fn analyze(&self, plan: LogicalPlan, _: &ConfigOptions) -> Result { + fn analyze(&self, plan: LogicalPlan, _: &dyn AnalyzerConfig) -> Result { plan.transform_down(&analyze_internal) } diff --git a/datafusion/optimizer/src/analyzer/function_name_resolver.rs b/datafusion/optimizer/src/analyzer/function_name_resolver.rs new file mode 100644 index 000000000000..d665f520d2d2 --- /dev/null +++ b/datafusion/optimizer/src/analyzer/function_name_resolver.rs @@ -0,0 +1,213 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::analyzer::AnalyzerRule; + +use datafusion_common::tree_node::TreeNodeRewriter; +use datafusion_common::DataFusionError; +use datafusion_common::{internal_err, Result}; +use datafusion_expr::expr::ScalarFunction; +use datafusion_expr::expr_rewriter::rewrite_preserving_name; +use datafusion_expr::{BuiltinScalarFunction, ScalarFunctionDefinition}; + +use datafusion_execution::FunctionRegistry; +use datafusion_expr::{logical_plan::LogicalPlan, Expr}; +use std::str::FromStr; + +use crate::analyzer::AnalyzerConfig; + +/// Resolves `ScalarFunctionDefinition::Name` at execution time. +/// + +pub struct ResolveFunctionByName {} + +impl ResolveFunctionByName { + pub fn new() -> Self { + ResolveFunctionByName {} + } +} + +impl Default for ResolveFunctionByName { + fn default() -> Self { + Self::new() + } +} + +impl AnalyzerRule for ResolveFunctionByName { + fn analyze( + &self, + plan: LogicalPlan, + config: &dyn AnalyzerConfig, + ) -> Result { + analyze_internal(&plan, config.function_registry()) + } + + fn name(&self) -> &str { + "resolve_function_by_name" + } +} + +fn analyze_internal( + plan: &LogicalPlan, + registry: &dyn FunctionRegistry, +) -> Result { + // optimize child plans first + let new_inputs = plan + .inputs() + .iter() + .map(|p| analyze_internal(p, registry)) + .collect::>>()?; + + let mut expr_rewrite = FunctionResolverRewriter { registry }; + + let new_expr = plan + .expressions() + .into_iter() + .map(|expr| rewrite_preserving_name(expr, &mut expr_rewrite)) + .collect::>>()?; + plan.with_new_exprs(new_expr, &new_inputs) +} + +struct FunctionResolverRewriter<'a> { + registry: &'a dyn FunctionRegistry, +} + +impl<'a> TreeNodeRewriter for FunctionResolverRewriter<'a> { + type N = Expr; + + fn mutate(&mut self, old_expr: Expr) -> Result { + match old_expr.clone() { + Expr::ScalarFunction(ScalarFunction { + func_def: ScalarFunctionDefinition::Name(name), + args, + }) => { + // user-defined function (UDF) should have precedence in case it has the same name as a scalar built-in function + if let Ok(fm) = self.registry.udf(&name) { + return Ok(Expr::ScalarFunction(ScalarFunction::new_udf(fm, args))); + } + + // next, scalar built-in + if let Ok(fun) = BuiltinScalarFunction::from_str(&name) { + return Ok(Expr::ScalarFunction(ScalarFunction::new(fun, args))); + } + internal_err!("Unknown scalar function") + } + _ => Ok(old_expr), + } + } +} + +#[cfg(test)] +mod tests { + use arrow::datatypes::DataType; + use datafusion_common::ScalarValue; + use datafusion_expr::{ + ColumnarValue, ReturnTypeFunction, ScalarFunctionImplementation, Signature, + Volatility, + }; + use std::sync::Arc; + + use super::*; + + #[derive(Debug, PartialEq, Eq, Hash)] + struct MockRegistry {} + + impl FunctionRegistry for MockRegistry { + fn udfs(&self) -> std::collections::HashSet { + todo!() + } + + fn udf(&self, name: &str) -> Result> { + if name == "my-udf" { + let return_type: ReturnTypeFunction = + Arc::new(move |_| Ok(Arc::new(DataType::Utf8))); + let fun: ScalarFunctionImplementation = Arc::new(move |_| { + Ok(ColumnarValue::Scalar(ScalarValue::new_utf8("a"))) + }); + return Ok(Arc::new(datafusion_expr::ScalarUDF::new( + "my-udf", + &Signature::uniform(1, vec![DataType::Float32], Volatility::Stable), + &return_type, + &fun, + ))); + } + internal_err!("function not found") + } + + fn udaf(&self, _name: &str) -> Result> { + todo!() + } + + fn udwf(&self, _name: &str) -> Result> { + todo!() + } + } + + fn rewrite(function: ScalarFunction) -> Result { + let registry = MockRegistry {}; + let mut rewriter = FunctionResolverRewriter { + registry: ®istry, + }; + rewriter.mutate(Expr::ScalarFunction(function)) + } + + #[test] + fn rewriter_rewrites_builtin_correctly() { + let function = datafusion_expr::expr::ScalarFunction { + func_def: ScalarFunctionDefinition::Name(Arc::from("log10")), + args: vec![], + }; + let result = rewrite(function); + assert!(matches!( + result, + Ok(Expr::ScalarFunction( + datafusion_expr::expr::ScalarFunction { + func_def: ScalarFunctionDefinition::BuiltIn( + BuiltinScalarFunction::Log + ), + .. + } + )) + )); + } + #[test] + fn rewriter_rewrites_udf_correctly() { + let function = datafusion_expr::expr::ScalarFunction { + func_def: ScalarFunctionDefinition::Name(Arc::from("my-udf")), + args: vec![], + }; + let result = rewrite(function); + if let Ok(Expr::ScalarFunction(datafusion_expr::expr::ScalarFunction { + func_def: ScalarFunctionDefinition::UDF(udf), + .. + })) = result + { + assert_eq!(udf.name(), "my-udf"); + } else { + panic!("Pattern did not match"); + } + } + #[test] + fn rewriter_fails_unknown_function() { + let function = datafusion_expr::expr::ScalarFunction { + func_def: ScalarFunctionDefinition::Name(Arc::from("my-udf-invalid")), + args: vec![], + }; + let result = rewrite(function); + assert!(result.is_err()); + } +} diff --git a/datafusion/optimizer/src/analyzer/inline_table_scan.rs b/datafusion/optimizer/src/analyzer/inline_table_scan.rs index 90af7aec8293..d05b0b144d98 100644 --- a/datafusion/optimizer/src/analyzer/inline_table_scan.rs +++ b/datafusion/optimizer/src/analyzer/inline_table_scan.rs @@ -20,7 +20,7 @@ use std::sync::Arc; use crate::analyzer::AnalyzerRule; -use datafusion_common::config::ConfigOptions; + use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::Result; use datafusion_expr::expr::Exists; @@ -29,6 +29,8 @@ use datafusion_expr::{ logical_plan::LogicalPlan, Expr, Filter, LogicalPlanBuilder, TableScan, }; +use crate::analyzer::AnalyzerConfig; + /// Analyzed rule that inlines TableScan that provide a [`LogicalPlan`] /// (DataFrame / ViewTable) #[derive(Default)] @@ -41,7 +43,7 @@ impl InlineTableScan { } impl AnalyzerRule for InlineTableScan { - fn analyze(&self, plan: LogicalPlan, _: &ConfigOptions) -> Result { + fn analyze(&self, plan: LogicalPlan, _: &dyn AnalyzerConfig) -> Result { plan.transform_up(&analyze_internal) } diff --git a/datafusion/optimizer/src/analyzer/mod.rs b/datafusion/optimizer/src/analyzer/mod.rs index 14d5ddf47378..4f9af220cb93 100644 --- a/datafusion/optimizer/src/analyzer/mod.rs +++ b/datafusion/optimizer/src/analyzer/mod.rs @@ -16,11 +16,13 @@ // under the License. pub mod count_wildcard_rule; +pub mod function_name_resolver; pub mod inline_table_scan; pub mod subquery; pub mod type_coercion; use crate::analyzer::count_wildcard_rule::CountWildcardRule; +use crate::analyzer::function_name_resolver::ResolveFunctionByName; use crate::analyzer::inline_table_scan::InlineTableScan; use crate::analyzer::subquery::check_subquery_expr; @@ -29,6 +31,7 @@ use crate::utils::log_plan; use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{TreeNode, VisitRecursion}; use datafusion_common::{DataFusionError, Result}; +use datafusion_execution::FunctionRegistry; use datafusion_expr::expr::Exists; use datafusion_expr::expr::InSubquery; use datafusion_expr::utils::inspect_expr_pre; @@ -37,6 +40,14 @@ use log::debug; use std::sync::Arc; use std::time::Instant; +/// Options to control DataFusion Analyzer Passes. +pub trait AnalyzerConfig { + /// Return a function registry for resolving names + fn function_registry(&self) -> &dyn FunctionRegistry; + /// return datafusion configuration options + fn options(&self) -> &ConfigOptions; +} + /// [`AnalyzerRule`]s transform [`LogicalPlan`]s in some way to make /// the plan valid prior to the rest of the DataFusion optimization process. /// @@ -49,7 +60,11 @@ use std::time::Instant; /// it the same result in some more optimal way. pub trait AnalyzerRule { /// Rewrite `plan` - fn analyze(&self, plan: LogicalPlan, config: &ConfigOptions) -> Result; + fn analyze( + &self, + plan: LogicalPlan, + config: &dyn AnalyzerConfig, + ) -> Result; /// A human readable name for this analyzer rule fn name(&self) -> &str; @@ -74,6 +89,7 @@ impl Analyzer { Arc::new(InlineTableScan::new()), Arc::new(TypeCoercion::new()), Arc::new(CountWildcardRule::new()), + Arc::new(ResolveFunctionByName::new()), ]; Self::with_rules(rules) } @@ -88,7 +104,7 @@ impl Analyzer { pub fn execute_and_check( &self, plan: &LogicalPlan, - config: &ConfigOptions, + config: &dyn AnalyzerConfig, mut observer: F, ) -> Result where diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 91611251d9dd..9472d1e99c3f 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -21,7 +21,6 @@ use std::sync::Arc; use arrow::datatypes::{DataType, IntervalUnit}; -use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{RewriteRecursion, TreeNodeRewriter}; use datafusion_common::{ exec_err, internal_err, plan_datafusion_err, plan_err, DFSchema, DFSchemaRef, @@ -52,6 +51,8 @@ use datafusion_expr::{ use crate::analyzer::AnalyzerRule; +use crate::analyzer::AnalyzerConfig; + #[derive(Default)] pub struct TypeCoercion {} @@ -66,7 +67,7 @@ impl AnalyzerRule for TypeCoercion { "type_coercion" } - fn analyze(&self, plan: LogicalPlan, _: &ConfigOptions) -> Result { + fn analyze(&self, plan: LogicalPlan, _: &dyn AnalyzerConfig) -> Result { analyze_internal(&DFSchema::empty(), &plan) } } diff --git a/datafusion/optimizer/src/test/mod.rs b/datafusion/optimizer/src/test/mod.rs index e691fe9a5351..1b79b0b9e183 100644 --- a/datafusion/optimizer/src/test/mod.rs +++ b/datafusion/optimizer/src/test/mod.rs @@ -15,13 +15,16 @@ // specific language governing permissions and limitations // under the License. -use crate::analyzer::{Analyzer, AnalyzerRule}; +use crate::analyzer::{Analyzer, AnalyzerConfig, AnalyzerRule}; use crate::optimizer::{assert_schema_is_the_same, Optimizer}; use crate::{OptimizerContext, OptimizerRule}; use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::config::ConfigOptions; -use datafusion_common::{assert_contains, Result}; +use datafusion_common::{assert_contains, internal_err, Result}; +use datafusion_execution::FunctionRegistry; use datafusion_expr::{col, logical_plan::table_scan, LogicalPlan, LogicalPlanBuilder}; + +use datafusion_common::DataFusionError; use std::sync::Arc; pub mod user_defined; @@ -108,14 +111,56 @@ pub fn get_tpch_table_schema(table: &str) -> Schema { } } +struct EmptyRegistryAnalyzerConfig { + options: ConfigOptions, +} + +impl EmptyRegistryAnalyzerConfig { + fn new() -> Self { + let options = ConfigOptions::default(); + EmptyRegistryAnalyzerConfig { options } + } +} + +impl FunctionRegistry for EmptyRegistryAnalyzerConfig { + fn udfs(&self) -> std::collections::HashSet { + std::collections::HashSet::new() + } + + fn udf(&self, _name: &str) -> Result> { + internal_err!("empty registry") + } + + fn udaf(&self, _name: &str) -> Result> { + internal_err!("empty registry") + } + + fn udwf(&self, _name: &str) -> Result> { + internal_err!("empty registry") + } +} + +impl AnalyzerConfig for EmptyRegistryAnalyzerConfig { + fn function_registry(&self) -> &dyn FunctionRegistry { + self + } + + fn options(&self) -> &ConfigOptions { + &self.options + } +} + pub fn assert_analyzed_plan_eq( rule: Arc, plan: &LogicalPlan, expected: &str, ) -> Result<()> { - let options = ConfigOptions::default(); - let analyzed_plan = - Analyzer::with_rules(vec![rule]).execute_and_check(plan, &options, |_, _| {})?; + let analyzer_config = EmptyRegistryAnalyzerConfig::new(); + let analyzed_plan = Analyzer::with_rules(vec![rule]).execute_and_check( + plan, + &analyzer_config, + |_, _| {}, + )?; let formatted_plan = format!("{analyzed_plan:?}"); assert_eq!(formatted_plan, expected); @@ -126,9 +171,12 @@ pub fn assert_analyzed_plan_eq_display_indent( plan: &LogicalPlan, expected: &str, ) -> Result<()> { - let options = ConfigOptions::default(); - let analyzed_plan = - Analyzer::with_rules(vec![rule]).execute_and_check(plan, &options, |_, _| {})?; + let analyzer_config = EmptyRegistryAnalyzerConfig::new(); + let analyzed_plan = Analyzer::with_rules(vec![rule]).execute_and_check( + plan, + &analyzer_config, + |_, _| {}, + )?; let formatted_plan = analyzed_plan.display_indent_schema().to_string(); assert_eq!(formatted_plan, expected); @@ -140,9 +188,9 @@ pub fn assert_analyzer_check_err( plan: &LogicalPlan, expected: &str, ) { - let options = ConfigOptions::default(); + let analyzer_config = EmptyRegistryAnalyzerConfig::new(); let analyzed_plan = - Analyzer::with_rules(rules).execute_and_check(plan, &options, |_, _| {}); + Analyzer::with_rules(rules).execute_and_check(plan, &analyzer_config, |_, _| {}); match analyzed_plan { Ok(plan) => assert_eq!(format!("{}", plan.display_indent()), "An error"), Err(e) => { diff --git a/datafusion/optimizer/tests/optimizer_integration.rs b/datafusion/optimizer/tests/optimizer_integration.rs index d857c6154ea9..785b84691662 100644 --- a/datafusion/optimizer/tests/optimizer_integration.rs +++ b/datafusion/optimizer/tests/optimizer_integration.rs @@ -16,14 +16,15 @@ // under the License. use std::any::Any; -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::sync::Arc; use arrow::datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit}; use datafusion_common::config::ConfigOptions; -use datafusion_common::{plan_err, DataFusionError, Result}; +use datafusion_common::{internal_err, plan_err, DataFusionError, Result}; +use datafusion_execution::FunctionRegistry; use datafusion_expr::{AggregateUDF, LogicalPlan, ScalarUDF, TableSource, WindowUDF}; -use datafusion_optimizer::analyzer::Analyzer; +use datafusion_optimizer::analyzer::{Analyzer, AnalyzerConfig}; use datafusion_optimizer::optimizer::Optimizer; use datafusion_optimizer::{OptimizerConfig, OptimizerContext}; use datafusion_sql::planner::{ContextProvider, SqlToRel}; @@ -335,6 +336,38 @@ fn test_same_name_but_not_ambiguous() { assert_eq!(expected, format!("{plan:?}")); } +struct EmptyRegistryAnalyzerConfig<'a> { + config_options: &'a ConfigOptions, +} + +impl<'a> FunctionRegistry for EmptyRegistryAnalyzerConfig<'a> { + fn udfs(&self) -> std::collections::HashSet { + HashSet::new() + } + + fn udf(&self, _name: &str) -> Result> { + internal_err!("Mock function registry") + } + + fn udaf(&self, _name: &str) -> Result> { + internal_err!("Mock function registry") + } + + fn udwf(&self, _name: &str) -> Result> { + internal_err!("Mock function registry") + } +} + +impl<'a> AnalyzerConfig for EmptyRegistryAnalyzerConfig<'a> { + fn function_registry(&self) -> &dyn FunctionRegistry { + self + } + + fn options(&self) -> &ConfigOptions { + self.config_options + } +} + fn test_sql(sql: &str) -> Result { // parse the SQL let dialect = GenericDialect {}; // or AnsiDialect, or your own dialect ... @@ -354,8 +387,11 @@ fn test_sql(sql: &str) -> Result { .with_query_execution_start_time(now_time); let analyzer = Analyzer::new(); let optimizer = Optimizer::new(); + let analyzer_config = EmptyRegistryAnalyzerConfig { + config_options: config.options(), + }; // analyze and optimize the logical plan - let plan = analyzer.execute_and_check(&plan, config.options(), |_, _| {})?; + let plan = analyzer.execute_and_check(&plan, &analyzer_config, |_, _| {})?; optimizer.optimize(&plan, &config, |_, _| {}) }