Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix functions with Volatility::Volatile and parameters #13001

Merged
merged 1 commit into from
Oct 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
181 changes: 181 additions & 0 deletions datafusion/core/tests/user_defined/user_defined_scalar_functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@
// under the License.

use std::any::Any;
use std::collections::HashMap;
use std::hash::{DefaultHasher, Hash, Hasher};
use std::sync::Arc;

use arrow::array::as_string_array;
use arrow::compute::kernels::numeric::add;
use arrow_array::builder::BooleanBuilder;
use arrow_array::cast::AsArray;
Expand Down Expand Up @@ -483,6 +485,185 @@ async fn test_user_defined_functions_with_alias() -> Result<()> {
Ok(())
}

/// Volatile UDF that should append a different value to each row
#[derive(Debug)]
struct AddIndexToStringVolatileScalarUDF {
name: String,
signature: Signature,
return_type: DataType,
}

impl AddIndexToStringVolatileScalarUDF {
fn new() -> Self {
Self {
name: "add_index_to_string".to_string(),
signature: Signature::exact(vec![DataType::Utf8], Volatility::Volatile),
return_type: DataType::Utf8,
}
}
}

impl ScalarUDFImpl for AddIndexToStringVolatileScalarUDF {
fn as_any(&self) -> &dyn Any {
self
}

fn name(&self) -> &str {
&self.name
}

fn signature(&self) -> &Signature {
&self.signature
}

fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
Ok(self.return_type.clone())
}

fn invoke(&self, _args: &[ColumnarValue]) -> Result<ColumnarValue> {
not_impl_err!("index_with_offset function does not accept arguments")
}

fn invoke_batch(
&self,
args: &[ColumnarValue],
number_rows: usize,
) -> Result<ColumnarValue> {
let answer = match &args[0] {
// When called with static arguments, the result is returned as an array.
ColumnarValue::Scalar(ScalarValue::Utf8(Some(value))) => {
let mut answer = vec![];
for index in 1..=number_rows {
// When calling a function with immutable arguments, the result is returned with ")".
// Example: SELECT add_index_to_string('const_value') FROM table;
answer.push(index.to_string() + ") " + value);
}
answer
}
// The result is returned as an array when called with dynamic arguments.
ColumnarValue::Array(array) => {
let string_array = as_string_array(array);
let mut counter = HashMap::<&str, u64>::new();
string_array
.iter()
.map(|value| {
let value = value.expect("Unexpected null");
let index = counter.get(value).unwrap_or(&0) + 1;
counter.insert(value, index);

// When calling a function with mutable arguments, the result is returned with ".".
// Example: SELECT add_index_to_string(table.value) FROM table;
index.to_string() + ". " + value
})
.collect()
}
_ => unimplemented!(),
};
Ok(ColumnarValue::Array(Arc::new(StringArray::from(answer))))
}
}

#[tokio::test]
async fn volatile_scalar_udf_with_params() -> Result<()> {
{
let schema = Schema::new(vec![Field::new("a", DataType::Utf8, false)]);

let batch = RecordBatch::try_new(
Arc::new(schema.clone()),
vec![Arc::new(StringArray::from(vec![
"test_1", "test_1", "test_1", "test_2", "test_2", "test_1", "test_2",
]))],
)?;
let ctx = SessionContext::new();

ctx.register_batch("t", batch)?;

let get_new_str_udf = AddIndexToStringVolatileScalarUDF::new();

ctx.register_udf(ScalarUDF::from(get_new_str_udf));

let result =
plan_and_collect(&ctx, "select add_index_to_string(t.a) AS str from t") // with dynamic function parameters
.await?;
let expected = [
"+-----------+",
"| str |",
"+-----------+",
"| 1. test_1 |",
"| 2. test_1 |",
"| 3. test_1 |",
"| 1. test_2 |",
"| 2. test_2 |",
"| 4. test_1 |",
"| 3. test_2 |",
"+-----------+",
];
assert_batches_eq!(expected, &result);

let result =
plan_and_collect(&ctx, "select add_index_to_string('test') AS str from t") // with fixed function parameters
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice

.await?;
let expected = [
"+---------+",
"| str |",
"+---------+",
"| 1) test |",
"| 2) test |",
"| 3) test |",
"| 4) test |",
"| 5) test |",
"| 6) test |",
"| 7) test |",
"+---------+",
];
assert_batches_eq!(expected, &result);

let result =
plan_and_collect(&ctx, "select add_index_to_string('test_value') as str") // with fixed function parameters
.await?;
let expected = [
"+---------------+",
"| str |",
"+---------------+",
"| 1) test_value |",
"+---------------+",
];
assert_batches_eq!(expected, &result);
}
{
Comment on lines +632 to +633
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: consider separating test cases into separate test functions, this would given them descriptive names

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These tests are very similar, so they don't break down into two parts well.

let schema = Schema::new(vec![Field::new("a", DataType::Utf8, false)]);

let batch = RecordBatch::try_new(
Arc::new(schema.clone()),
vec![Arc::new(StringArray::from(vec![
"test_1", "test_1", "test_1",
]))],
)?;
let ctx = SessionContext::new();

ctx.register_batch("t", batch)?;

let get_new_str_udf = AddIndexToStringVolatileScalarUDF::new();

ctx.register_udf(ScalarUDF::from(get_new_str_udf));

let result =
plan_and_collect(&ctx, "select add_index_to_string(t.a) AS str from t")
.await?;
let expected = [
"+-----------+", //
"| str |", //
"+-----------+", //
"| 1. test_1 |", //
"| 2. test_1 |", //
"| 3. test_1 |", //
"+-----------+",
];
assert_batches_eq!(expected, &result);
}
Ok(())
}

#[derive(Debug)]
struct CastToI64UDF {
signature: Signature,
Expand Down
31 changes: 30 additions & 1 deletion datafusion/expr/src/udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,17 @@ impl ScalarUDF {
self.inner.is_nullable(args, schema)
}

/// Invoke the function with `args` and number of rows, returning the appropriate result.
///
/// See [`ScalarUDFImpl::invoke_batch`] for more details.
pub fn invoke_batch(
&self,
args: &[ColumnarValue],
number_rows: usize,
) -> Result<ColumnarValue> {
self.inner.invoke_batch(args, number_rows)
}

/// Invoke the function without `args` but number of rows, returning the appropriate result.
///
/// See [`ScalarUDFImpl::invoke_no_args`] for more details.
Expand Down Expand Up @@ -467,7 +478,25 @@ pub trait ScalarUDFImpl: Debug + Send + Sync {
/// to arrays, which will likely be simpler code, but be slower.
///
/// [invoke_no_args]: ScalarUDFImpl::invoke_no_args
fn invoke(&self, _args: &[ColumnarValue]) -> Result<ColumnarValue>;
fn invoke(&self, _args: &[ColumnarValue]) -> Result<ColumnarValue> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice! -- as a follow on PR I think we should deprecate the other two functions (invoke_no_args and invoke) telling people to use invoke instead

Is this ok with you @jayzhan211 ?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

as a follow on PR I think we should deprecate the other two functions (invoke_no_args and invoke) telling people to use invoke instead

did you mean invoke_batch?

yes, it would be great to have only one invoke entry-point

not_impl_err!(
"Function {} does not implement invoke but called",
self.name()
)
}

/// Invoke the function with `args` and the number of rows,
/// returning the appropriate result.
fn invoke_batch(
&self,
args: &[ColumnarValue],
number_rows: usize,
) -> Result<ColumnarValue> {
match args.is_empty() {
true => self.invoke_no_args(number_rows),
false => self.invoke(args),
}
}

/// Invoke the function without `args`, instead the number of rows are provided,
/// returning the appropriate result.
Expand Down
5 changes: 1 addition & 4 deletions datafusion/physical-expr/src/scalar_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,10 +141,7 @@ impl PhysicalExpr for ScalarFunctionExpr {
.collect::<Result<Vec<_>>>()?;

// evaluate the function
let output = match self.args.is_empty() {
true => self.fun.invoke_no_args(batch.num_rows()),
false => self.fun.invoke(&inputs),
}?;
let output = self.fun.invoke_batch(&inputs, batch.num_rows())?;

if let ColumnarValue::Array(array) = &output {
if array.len() != batch.num_rows() {
Expand Down