diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index 5e3c8648fc25..b4af7896821b 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -1363,6 +1363,7 @@ dependencies = [ "datafusion-expr", "log", "sqlparser", + "strum 0.26.1", ] [[package]] diff --git a/datafusion-examples/examples/rewrite_expr.rs b/datafusion-examples/examples/rewrite_expr.rs index cc1396f770e4..541448ebf149 100644 --- a/datafusion-examples/examples/rewrite_expr.rs +++ b/datafusion-examples/examples/rewrite_expr.rs @@ -226,6 +226,18 @@ impl ContextProvider for MyContextProvider { fn options(&self) -> &ConfigOptions { &self.options } + + fn udfs_names(&self) -> Vec { + Vec::new() + } + + fn udafs_names(&self) -> Vec { + Vec::new() + } + + fn udwfs_names(&self) -> Vec { + Vec::new() + } } struct MyTableSource { diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index 7b37e4914cf9..49d1b12e6646 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -2098,6 +2098,18 @@ impl<'a> ContextProvider for SessionContextProvider<'a> { fn options(&self) -> &ConfigOptions { self.state.config_options() } + + fn udfs_names(&self) -> Vec { + self.state.scalar_functions().keys().cloned().collect() + } + + fn udafs_names(&self) -> Vec { + self.state.aggregate_functions().keys().cloned().collect() + } + + fn udwfs_names(&self) -> Vec { + self.state.window_functions().keys().cloned().collect() + } } impl FunctionRegistry for SessionState { diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs index 3e30a5574be0..a3760eeb357d 100644 --- a/datafusion/expr/src/function.rs +++ b/datafusion/expr/src/function.rs @@ -17,13 +17,12 @@ //! Function module contains typing and signature for built-in and user defined functions. -use crate::{Accumulator, BuiltinScalarFunction, PartitionEvaluator, Signature}; -use crate::{AggregateFunction, BuiltInWindowFunction, ColumnarValue}; +use crate::{ + Accumulator, BuiltinScalarFunction, ColumnarValue, PartitionEvaluator, Signature, +}; use arrow::datatypes::DataType; -use datafusion_common::utils::datafusion_strsim; use datafusion_common::Result; use std::sync::Arc; -use strum::IntoEnumIterator; /// Scalar function /// @@ -75,33 +74,3 @@ pub fn return_type( pub fn signature(fun: &BuiltinScalarFunction) -> Signature { fun.signature() } - -/// Suggest a valid function based on an invalid input function name -pub fn suggest_valid_function(input_function_name: &str, is_window_func: bool) -> String { - let valid_funcs = if is_window_func { - // All aggregate functions and builtin window functions - AggregateFunction::iter() - .map(|func| func.to_string()) - .chain(BuiltInWindowFunction::iter().map(|func| func.to_string())) - .collect() - } else { - // All scalar functions and aggregate functions - BuiltinScalarFunction::iter() - .map(|func| func.to_string()) - .chain(AggregateFunction::iter().map(|func| func.to_string())) - .collect() - }; - find_closest_match(valid_funcs, input_function_name) -} - -/// Find the closest matching string to the target string in the candidates list, using edit distance(case insensitve) -/// Input `candidates` must not be empty otherwise it will panic -fn find_closest_match(candidates: Vec, target: &str) -> String { - let target = target.to_lowercase(); - candidates - .into_iter() - .min_by_key(|candidate| { - datafusion_strsim::levenshtein(&candidate.to_lowercase(), &target) - }) - .expect("No candidates provided.") // Panic if `candidates` argument is empty -} diff --git a/datafusion/optimizer/tests/optimizer_integration.rs b/datafusion/optimizer/tests/optimizer_integration.rs index db7bfa8b3bc8..b02623854b8a 100644 --- a/datafusion/optimizer/tests/optimizer_integration.rs +++ b/datafusion/optimizer/tests/optimizer_integration.rs @@ -417,6 +417,18 @@ impl ContextProvider for MyContextProvider { fn options(&self) -> &ConfigOptions { &self.options } + + fn udfs_names(&self) -> Vec { + Vec::new() + } + + fn udafs_names(&self) -> Vec { + Vec::new() + } + + fn udwfs_names(&self) -> Vec { + Vec::new() + } } struct MyTableSource { diff --git a/datafusion/sql/Cargo.toml b/datafusion/sql/Cargo.toml index fb300e2c8791..7739058a5c9d 100644 --- a/datafusion/sql/Cargo.toml +++ b/datafusion/sql/Cargo.toml @@ -43,6 +43,7 @@ datafusion-common = { workspace = true, default-features = true } datafusion-expr = { workspace = true } log = { workspace = true } sqlparser = { workspace = true } +strum = { version = "0.26.1", features = ["derive"] } [dev-dependencies] ctor = { workspace = true } diff --git a/datafusion/sql/examples/sql.rs b/datafusion/sql/examples/sql.rs index 8744a905481f..5bab2f19cfc0 100644 --- a/datafusion/sql/examples/sql.rs +++ b/datafusion/sql/examples/sql.rs @@ -131,4 +131,16 @@ impl ContextProvider for MyContextProvider { fn options(&self) -> &ConfigOptions { &self.options } + + fn udfs_names(&self) -> Vec { + Vec::new() + } + + fn udafs_names(&self) -> Vec { + Vec::new() + } + + fn udwfs_names(&self) -> Vec { + Vec::new() + } } diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index bcf641e4b5a0..ffc951a6fa66 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -20,20 +20,67 @@ use arrow_schema::DataType; use datafusion_common::{ not_impl_err, plan_datafusion_err, plan_err, DFSchema, Dependency, Result, }; -use datafusion_expr::expr::{ScalarFunction, Unnest}; -use datafusion_expr::function::suggest_valid_function; use datafusion_expr::window_frame::{check_window_frame, regularize_window_order_by}; use datafusion_expr::{ - expr, AggregateFunction, BuiltinScalarFunction, Expr, ExprSchemable, WindowFrame, - WindowFunctionDefinition, + expr, AggregateFunction, Expr, ExprSchemable, WindowFrame, WindowFunctionDefinition, +}; +use datafusion_expr::{ + expr::{ScalarFunction, Unnest}, + BuiltInWindowFunction, BuiltinScalarFunction, }; use sqlparser::ast::{ Expr as SQLExpr, Function as SQLFunction, FunctionArg, FunctionArgExpr, WindowType, }; use std::str::FromStr; +use strum::IntoEnumIterator; use super::arrow_cast::ARROW_CAST_NAME; +/// Suggest a valid function based on an invalid input function name +pub fn suggest_valid_function( + input_function_name: &str, + is_window_func: bool, + ctx: &dyn ContextProvider, +) -> String { + let valid_funcs = if is_window_func { + // All aggregate functions and builtin window functions + let mut funcs = Vec::new(); + + funcs.extend(AggregateFunction::iter().map(|func| func.to_string())); + funcs.extend(ctx.udafs_names()); + funcs.extend(BuiltInWindowFunction::iter().map(|func| func.to_string())); + funcs.extend(ctx.udwfs_names()); + + funcs + } else { + // All scalar functions and aggregate functions + let mut funcs = Vec::new(); + + funcs.extend(BuiltinScalarFunction::iter().map(|func| func.to_string())); + funcs.extend(ctx.udfs_names()); + funcs.extend(AggregateFunction::iter().map(|func| func.to_string())); + funcs.extend(ctx.udafs_names()); + + funcs + }; + find_closest_match(valid_funcs, input_function_name) +} + +/// Find the closest matching string to the target string in the candidates list, using edit distance(case insensitve) +/// Input `candidates` must not be empty otherwise it will panic +fn find_closest_match(candidates: Vec, target: &str) -> String { + let target = target.to_lowercase(); + candidates + .into_iter() + .min_by_key(|candidate| { + datafusion_common::utils::datafusion_strsim::levenshtein( + &candidate.to_lowercase(), + &target, + ) + }) + .expect("No candidates provided.") // Panic if `candidates` argument is empty +} + impl<'a, S: ContextProvider> SqlToRel<'a, S> { pub(super) fn sql_function_to_expr( &self, @@ -211,7 +258,8 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } // Could not find the relevant function, so return an error - let suggested_func_name = suggest_valid_function(&name, is_function_window); + let suggested_func_name = + suggest_valid_function(&name, is_function_window, self.context_provider); plan_err!("Invalid function '{name}'.\nDid you mean '{suggested_func_name}'?") } diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index d6aa006ec3b3..e838a4cafb2a 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -983,6 +983,18 @@ mod tests { fn get_window_meta(&self, _name: &str) -> Option> { None } + + fn udfs_names(&self) -> Vec { + Vec::new() + } + + fn udafs_names(&self) -> Vec { + Vec::new() + } + + fn udwfs_names(&self) -> Vec { + Vec::new() + } } fn create_table_source(fields: Vec) -> Arc { diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index 2db2c01c5ee1..f94c6ec4e8c9 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -85,6 +85,10 @@ pub trait ContextProvider { /// Get configuration options fn options(&self) -> &ConfigOptions; + + fn udfs_names(&self) -> Vec; + fn udafs_names(&self) -> Vec; + fn udwfs_names(&self) -> Vec; } /// SQL parser options diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index 655eb63cc380..6681c3d02564 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -2901,6 +2901,18 @@ impl ContextProvider for MockContextProvider { ) -> Result> { Ok(Arc::new(EmptyTable::new(schema))) } + + fn udfs_names(&self) -> Vec { + self.udfs.keys().cloned().collect() + } + + fn udafs_names(&self) -> Vec { + self.udafs.keys().cloned().collect() + } + + fn udwfs_names(&self) -> Vec { + Vec::new() + } } #[test] diff --git a/datafusion/sqllogictest/test_files/functions.slt b/datafusion/sqllogictest/test_files/functions.slt index 96aa3e275209..21433ba16810 100644 --- a/datafusion/sqllogictest/test_files/functions.slt +++ b/datafusion/sqllogictest/test_files/functions.slt @@ -483,7 +483,7 @@ statement error Did you mean 'arrow_typeof'? SELECT arrowtypeof(v1) from test; # Scalar function -statement error Invalid function 'to_timestamps_second' +statement error Did you mean 'to_timestamp_seconds'? SELECT to_TIMESTAMPS_second(v2) from test; # Aggregate function