Skip to content

Commit

Permalink
feat: function name hints for UDFs (apache#9407)
Browse files Browse the repository at this point in the history
* feat: function name hints for UDFs

* refactor: rebase fn to xxx_names()

* style: fix clippy

* style: fix clippy

* Add test

---------

Co-authored-by: Andrew Lamb <[email protected]>
  • Loading branch information
SteveLauC and alamb authored Mar 10, 2024
1 parent 96664ce commit f1f0965
Show file tree
Hide file tree
Showing 12 changed files with 135 additions and 40 deletions.
1 change: 1 addition & 0 deletions datafusion-cli/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 12 additions & 0 deletions datafusion-examples/examples/rewrite_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,18 @@ impl ContextProvider for MyContextProvider {
fn options(&self) -> &ConfigOptions {
&self.options
}

fn udfs_names(&self) -> Vec<String> {
Vec::new()
}

fn udafs_names(&self) -> Vec<String> {
Vec::new()
}

fn udwfs_names(&self) -> Vec<String> {
Vec::new()
}
}

struct MyTableSource {
Expand Down
12 changes: 12 additions & 0 deletions datafusion/core/src/execution/context/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2098,6 +2098,18 @@ impl<'a> ContextProvider for SessionContextProvider<'a> {
fn options(&self) -> &ConfigOptions {
self.state.config_options()
}

fn udfs_names(&self) -> Vec<String> {
self.state.scalar_functions().keys().cloned().collect()
}

fn udafs_names(&self) -> Vec<String> {
self.state.aggregate_functions().keys().cloned().collect()
}

fn udwfs_names(&self) -> Vec<String> {
self.state.window_functions().keys().cloned().collect()
}
}

impl FunctionRegistry for SessionState {
Expand Down
37 changes: 3 additions & 34 deletions datafusion/expr/src/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,12 @@

//! Function module contains typing and signature for built-in and user defined functions.
use crate::{Accumulator, BuiltinScalarFunction, PartitionEvaluator, Signature};
use crate::{AggregateFunction, BuiltInWindowFunction, ColumnarValue};
use crate::{
Accumulator, BuiltinScalarFunction, ColumnarValue, PartitionEvaluator, Signature,
};
use arrow::datatypes::DataType;
use datafusion_common::utils::datafusion_strsim;
use datafusion_common::Result;
use std::sync::Arc;
use strum::IntoEnumIterator;

/// Scalar function
///
Expand Down Expand Up @@ -75,33 +74,3 @@ pub fn return_type(
pub fn signature(fun: &BuiltinScalarFunction) -> Signature {
fun.signature()
}

/// Suggest a valid function based on an invalid input function name
pub fn suggest_valid_function(input_function_name: &str, is_window_func: bool) -> String {
let valid_funcs = if is_window_func {
// All aggregate functions and builtin window functions
AggregateFunction::iter()
.map(|func| func.to_string())
.chain(BuiltInWindowFunction::iter().map(|func| func.to_string()))
.collect()
} else {
// All scalar functions and aggregate functions
BuiltinScalarFunction::iter()
.map(|func| func.to_string())
.chain(AggregateFunction::iter().map(|func| func.to_string()))
.collect()
};
find_closest_match(valid_funcs, input_function_name)
}

/// Find the closest matching string to the target string in the candidates list, using edit distance(case insensitve)
/// Input `candidates` must not be empty otherwise it will panic
fn find_closest_match(candidates: Vec<String>, target: &str) -> String {
let target = target.to_lowercase();
candidates
.into_iter()
.min_by_key(|candidate| {
datafusion_strsim::levenshtein(&candidate.to_lowercase(), &target)
})
.expect("No candidates provided.") // Panic if `candidates` argument is empty
}
12 changes: 12 additions & 0 deletions datafusion/optimizer/tests/optimizer_integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,18 @@ impl ContextProvider for MyContextProvider {
fn options(&self) -> &ConfigOptions {
&self.options
}

fn udfs_names(&self) -> Vec<String> {
Vec::new()
}

fn udafs_names(&self) -> Vec<String> {
Vec::new()
}

fn udwfs_names(&self) -> Vec<String> {
Vec::new()
}
}

struct MyTableSource {
Expand Down
1 change: 1 addition & 0 deletions datafusion/sql/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ datafusion-common = { workspace = true, default-features = true }
datafusion-expr = { workspace = true }
log = { workspace = true }
sqlparser = { workspace = true }
strum = { version = "0.26.1", features = ["derive"] }

[dev-dependencies]
ctor = { workspace = true }
Expand Down
12 changes: 12 additions & 0 deletions datafusion/sql/examples/sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,4 +131,16 @@ impl ContextProvider for MyContextProvider {
fn options(&self) -> &ConfigOptions {
&self.options
}

fn udfs_names(&self) -> Vec<String> {
Vec::new()
}

fn udafs_names(&self) -> Vec<String> {
Vec::new()
}

fn udwfs_names(&self) -> Vec<String> {
Vec::new()
}
}
58 changes: 53 additions & 5 deletions datafusion/sql/src/expr/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,67 @@ use arrow_schema::DataType;
use datafusion_common::{
not_impl_err, plan_datafusion_err, plan_err, DFSchema, Dependency, Result,
};
use datafusion_expr::expr::{ScalarFunction, Unnest};
use datafusion_expr::function::suggest_valid_function;
use datafusion_expr::window_frame::{check_window_frame, regularize_window_order_by};
use datafusion_expr::{
expr, AggregateFunction, BuiltinScalarFunction, Expr, ExprSchemable, WindowFrame,
WindowFunctionDefinition,
expr, AggregateFunction, Expr, ExprSchemable, WindowFrame, WindowFunctionDefinition,
};
use datafusion_expr::{
expr::{ScalarFunction, Unnest},
BuiltInWindowFunction, BuiltinScalarFunction,
};
use sqlparser::ast::{
Expr as SQLExpr, Function as SQLFunction, FunctionArg, FunctionArgExpr, WindowType,
};
use std::str::FromStr;
use strum::IntoEnumIterator;

use super::arrow_cast::ARROW_CAST_NAME;

/// Suggest a valid function based on an invalid input function name
pub fn suggest_valid_function(
input_function_name: &str,
is_window_func: bool,
ctx: &dyn ContextProvider,
) -> String {
let valid_funcs = if is_window_func {
// All aggregate functions and builtin window functions
let mut funcs = Vec::new();

funcs.extend(AggregateFunction::iter().map(|func| func.to_string()));
funcs.extend(ctx.udafs_names());
funcs.extend(BuiltInWindowFunction::iter().map(|func| func.to_string()));
funcs.extend(ctx.udwfs_names());

funcs
} else {
// All scalar functions and aggregate functions
let mut funcs = Vec::new();

funcs.extend(BuiltinScalarFunction::iter().map(|func| func.to_string()));
funcs.extend(ctx.udfs_names());
funcs.extend(AggregateFunction::iter().map(|func| func.to_string()));
funcs.extend(ctx.udafs_names());

funcs
};
find_closest_match(valid_funcs, input_function_name)
}

/// Find the closest matching string to the target string in the candidates list, using edit distance(case insensitve)
/// Input `candidates` must not be empty otherwise it will panic
fn find_closest_match(candidates: Vec<String>, target: &str) -> String {
let target = target.to_lowercase();
candidates
.into_iter()
.min_by_key(|candidate| {
datafusion_common::utils::datafusion_strsim::levenshtein(
&candidate.to_lowercase(),
&target,
)
})
.expect("No candidates provided.") // Panic if `candidates` argument is empty
}

impl<'a, S: ContextProvider> SqlToRel<'a, S> {
pub(super) fn sql_function_to_expr(
&self,
Expand Down Expand Up @@ -211,7 +258,8 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
}

// Could not find the relevant function, so return an error
let suggested_func_name = suggest_valid_function(&name, is_function_window);
let suggested_func_name =
suggest_valid_function(&name, is_function_window, self.context_provider);
plan_err!("Invalid function '{name}'.\nDid you mean '{suggested_func_name}'?")
}

Expand Down
12 changes: 12 additions & 0 deletions datafusion/sql/src/expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -983,6 +983,18 @@ mod tests {
fn get_window_meta(&self, _name: &str) -> Option<Arc<WindowUDF>> {
None
}

fn udfs_names(&self) -> Vec<String> {
Vec::new()
}

fn udafs_names(&self) -> Vec<String> {
Vec::new()
}

fn udwfs_names(&self) -> Vec<String> {
Vec::new()
}
}

fn create_table_source(fields: Vec<Field>) -> Arc<dyn TableSource> {
Expand Down
4 changes: 4 additions & 0 deletions datafusion/sql/src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,10 @@ pub trait ContextProvider {

/// Get configuration options
fn options(&self) -> &ConfigOptions;

fn udfs_names(&self) -> Vec<String>;
fn udafs_names(&self) -> Vec<String>;
fn udwfs_names(&self) -> Vec<String>;
}

/// SQL parser options
Expand Down
12 changes: 12 additions & 0 deletions datafusion/sql/tests/sql_integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2901,6 +2901,18 @@ impl ContextProvider for MockContextProvider {
) -> Result<Arc<dyn TableSource>> {
Ok(Arc::new(EmptyTable::new(schema)))
}

fn udfs_names(&self) -> Vec<String> {
self.udfs.keys().cloned().collect()
}

fn udafs_names(&self) -> Vec<String> {
self.udafs.keys().cloned().collect()
}

fn udwfs_names(&self) -> Vec<String> {
Vec::new()
}
}

#[test]
Expand Down
2 changes: 1 addition & 1 deletion datafusion/sqllogictest/test_files/functions.slt
Original file line number Diff line number Diff line change
Expand Up @@ -483,7 +483,7 @@ statement error Did you mean 'arrow_typeof'?
SELECT arrowtypeof(v1) from test;

# Scalar function
statement error Invalid function 'to_timestamps_second'
statement error Did you mean 'to_timestamp_seconds'?
SELECT to_TIMESTAMPS_second(v2) from test;

# Aggregate function
Expand Down

0 comments on commit f1f0965

Please sign in to comment.