Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: function name hints for UDFs #9407

Merged
merged 6 commits into from
Mar 10, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions datafusion-cli/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 12 additions & 0 deletions datafusion-examples/examples/rewrite_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,18 @@ impl ContextProvider for MyContextProvider {
fn options(&self) -> &ConfigOptions {
&self.options
}

fn udfs_names(&self) -> Vec<String> {
Vec::new()
}

fn udafs_names(&self) -> Vec<String> {
Vec::new()
}

fn udwfs_names(&self) -> Vec<String> {
Vec::new()
}
}

struct MyTableSource {
Expand Down
12 changes: 12 additions & 0 deletions datafusion/core/src/execution/context/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2084,6 +2084,18 @@ impl<'a> ContextProvider for SessionContextProvider<'a> {
fn options(&self) -> &ConfigOptions {
self.state.config_options()
}

fn udfs_names(&self) -> Vec<String> {
self.state.scalar_functions().keys().cloned().collect()
}

fn udafs_names(&self) -> Vec<String> {
self.state.aggregate_functions().keys().cloned().collect()
}

fn udwfs_names(&self) -> Vec<String> {
self.state.window_functions().keys().cloned().collect()
}
}

impl FunctionRegistry for SessionState {
Expand Down
37 changes: 3 additions & 34 deletions datafusion/expr/src/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,12 @@

//! Function module contains typing and signature for built-in and user defined functions.

use crate::{Accumulator, BuiltinScalarFunction, PartitionEvaluator, Signature};
use crate::{AggregateFunction, BuiltInWindowFunction, ColumnarValue};
use crate::{
Accumulator, BuiltinScalarFunction, ColumnarValue, PartitionEvaluator, Signature,
};
use arrow::datatypes::DataType;
use datafusion_common::utils::datafusion_strsim;
use datafusion_common::Result;
use std::sync::Arc;
use strum::IntoEnumIterator;

/// Scalar function
///
Expand Down Expand Up @@ -75,33 +74,3 @@ pub fn return_type(
pub fn signature(fun: &BuiltinScalarFunction) -> Signature {
fun.signature()
}

/// Suggest a valid function based on an invalid input function name
pub fn suggest_valid_function(input_function_name: &str, is_window_func: bool) -> String {
let valid_funcs = if is_window_func {
// All aggregate functions and builtin window functions
AggregateFunction::iter()
.map(|func| func.to_string())
.chain(BuiltInWindowFunction::iter().map(|func| func.to_string()))
.collect()
} else {
// All scalar functions and aggregate functions
BuiltinScalarFunction::iter()
.map(|func| func.to_string())
.chain(AggregateFunction::iter().map(|func| func.to_string()))
.collect()
};
find_closest_match(valid_funcs, input_function_name)
}

/// Find the closest matching string to the target string in the candidates list, using edit distance(case insensitve)
/// Input `candidates` must not be empty otherwise it will panic
fn find_closest_match(candidates: Vec<String>, target: &str) -> String {
let target = target.to_lowercase();
candidates
.into_iter()
.min_by_key(|candidate| {
datafusion_strsim::levenshtein(&candidate.to_lowercase(), &target)
})
.expect("No candidates provided.") // Panic if `candidates` argument is empty
}
12 changes: 12 additions & 0 deletions datafusion/optimizer/tests/optimizer_integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,18 @@ impl ContextProvider for MyContextProvider {
fn options(&self) -> &ConfigOptions {
&self.options
}

fn udfs_names(&self) -> Vec<String> {
Vec::new()
}

fn udafs_names(&self) -> Vec<String> {
Vec::new()
}

fn udwfs_names(&self) -> Vec<String> {
Vec::new()
}
}

struct MyTableSource {
Expand Down
1 change: 1 addition & 0 deletions datafusion/sql/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ datafusion-common = { workspace = true, default-features = true }
datafusion-expr = { workspace = true }
log = { workspace = true }
sqlparser = { workspace = true }
strum = { version = "0.26.1", features = ["derive"] }

[dev-dependencies]
ctor = { workspace = true }
Expand Down
12 changes: 12 additions & 0 deletions datafusion/sql/examples/sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,4 +131,16 @@ impl ContextProvider for MyContextProvider {
fn options(&self) -> &ConfigOptions {
&self.options
}

fn udfs_names(&self) -> Vec<String> {
Vec::new()
}

fn udafs_names(&self) -> Vec<String> {
Vec::new()
}

fn udwfs_names(&self) -> Vec<String> {
Vec::new()
}
}
58 changes: 53 additions & 5 deletions datafusion/sql/src/expr/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,67 @@ use arrow_schema::DataType;
use datafusion_common::{
not_impl_err, plan_datafusion_err, plan_err, DFSchema, Dependency, Result,
};
use datafusion_expr::expr::{ScalarFunction, Unnest};
use datafusion_expr::function::suggest_valid_function;
use datafusion_expr::window_frame::{check_window_frame, regularize_window_order_by};
use datafusion_expr::{
expr, AggregateFunction, BuiltinScalarFunction, Expr, ExprSchemable, WindowFrame,
WindowFunctionDefinition,
expr, AggregateFunction, Expr, ExprSchemable, WindowFrame, WindowFunctionDefinition,
};
use datafusion_expr::{
expr::{ScalarFunction, Unnest},
BuiltInWindowFunction, BuiltinScalarFunction,
};
use sqlparser::ast::{
Expr as SQLExpr, Function as SQLFunction, FunctionArg, FunctionArgExpr, WindowType,
};
use std::str::FromStr;
use strum::IntoEnumIterator;

use super::arrow_cast::ARROW_CAST_NAME;

/// Suggest a valid function based on an invalid input function name
pub fn suggest_valid_function(
input_function_name: &str,
is_window_func: bool,
ctx: &dyn ContextProvider,
) -> String {
let valid_funcs = if is_window_func {
// All aggregate functions and builtin window functions
let mut funcs = Vec::new();

funcs.extend(AggregateFunction::iter().map(|func| func.to_string()));
funcs.extend(ctx.udafs_names());
funcs.extend(BuiltInWindowFunction::iter().map(|func| func.to_string()));
funcs.extend(ctx.udwfs_names());

funcs
} else {
// All scalar functions and aggregate functions
let mut funcs = Vec::new();

funcs.extend(BuiltinScalarFunction::iter().map(|func| func.to_string()));
funcs.extend(ctx.udfs_names());
funcs.extend(AggregateFunction::iter().map(|func| func.to_string()));
funcs.extend(ctx.udafs_names());

funcs
};
find_closest_match(valid_funcs, input_function_name)
}

/// Find the closest matching string to the target string in the candidates list, using edit distance(case insensitve)
/// Input `candidates` must not be empty otherwise it will panic
fn find_closest_match(candidates: Vec<String>, target: &str) -> String {
let target = target.to_lowercase();
candidates
.into_iter()
.min_by_key(|candidate| {
datafusion_common::utils::datafusion_strsim::levenshtein(
&candidate.to_lowercase(),
&target,
)
})
.expect("No candidates provided.") // Panic if `candidates` argument is empty
}

impl<'a, S: ContextProvider> SqlToRel<'a, S> {
pub(super) fn sql_function_to_expr(
&self,
Expand Down Expand Up @@ -211,7 +258,8 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
}

// Could not find the relevant function, so return an error
let suggested_func_name = suggest_valid_function(&name, is_function_window);
let suggested_func_name =
suggest_valid_function(&name, is_function_window, self.context_provider);
plan_err!("Invalid function '{name}'.\nDid you mean '{suggested_func_name}'?")
}

Expand Down
12 changes: 12 additions & 0 deletions datafusion/sql/src/expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -955,6 +955,18 @@ mod tests {
fn get_window_meta(&self, _name: &str) -> Option<Arc<WindowUDF>> {
None
}

fn udfs_names(&self) -> Vec<String> {
Vec::new()
}

fn udafs_names(&self) -> Vec<String> {
Vec::new()
}

fn udwfs_names(&self) -> Vec<String> {
Vec::new()
}
}

fn create_table_source(fields: Vec<Field>) -> Arc<dyn TableSource> {
Expand Down
4 changes: 4 additions & 0 deletions datafusion/sql/src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,10 @@ pub trait ContextProvider {

/// Get configuration options
fn options(&self) -> &ConfigOptions;

fn udfs_names(&self) -> Vec<String>;
fn udafs_names(&self) -> Vec<String>;
fn udwfs_names(&self) -> Vec<String>;
}

/// SQL parser options
Expand Down
12 changes: 12 additions & 0 deletions datafusion/sql/tests/sql_integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2896,6 +2896,18 @@ impl ContextProvider for MockContextProvider {
) -> Result<Arc<dyn TableSource>> {
Ok(Arc::new(EmptyTable::new(schema)))
}

fn udfs_names(&self) -> Vec<String> {
self.udfs.keys().cloned().collect()
}

fn udafs_names(&self) -> Vec<String> {
self.udafs.keys().cloned().collect()
}

fn udwfs_names(&self) -> Vec<String> {
Vec::new()
}
}

#[test]
Expand Down
Loading