-
Notifications
You must be signed in to change notification settings - Fork 1.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix functions with Volatility::Volatile and parameters #13001
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,9 +16,11 @@ | |
// under the License. | ||
|
||
use std::any::Any; | ||
use std::collections::HashMap; | ||
use std::hash::{DefaultHasher, Hash, Hasher}; | ||
use std::sync::Arc; | ||
|
||
use arrow::array::as_string_array; | ||
use arrow::compute::kernels::numeric::add; | ||
use arrow_array::builder::BooleanBuilder; | ||
use arrow_array::cast::AsArray; | ||
|
@@ -483,6 +485,185 @@ async fn test_user_defined_functions_with_alias() -> Result<()> { | |
Ok(()) | ||
} | ||
|
||
/// Volatile UDF that should append a different value to each row | ||
#[derive(Debug)] | ||
struct AddIndexToStringVolatileScalarUDF { | ||
name: String, | ||
signature: Signature, | ||
return_type: DataType, | ||
} | ||
|
||
impl AddIndexToStringVolatileScalarUDF { | ||
fn new() -> Self { | ||
Self { | ||
name: "add_index_to_string".to_string(), | ||
signature: Signature::exact(vec![DataType::Utf8], Volatility::Volatile), | ||
return_type: DataType::Utf8, | ||
} | ||
} | ||
} | ||
|
||
impl ScalarUDFImpl for AddIndexToStringVolatileScalarUDF { | ||
fn as_any(&self) -> &dyn Any { | ||
self | ||
} | ||
|
||
fn name(&self) -> &str { | ||
&self.name | ||
} | ||
|
||
fn signature(&self) -> &Signature { | ||
&self.signature | ||
} | ||
|
||
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> { | ||
Ok(self.return_type.clone()) | ||
} | ||
|
||
fn invoke(&self, _args: &[ColumnarValue]) -> Result<ColumnarValue> { | ||
not_impl_err!("index_with_offset function does not accept arguments") | ||
} | ||
|
||
fn invoke_batch( | ||
&self, | ||
args: &[ColumnarValue], | ||
number_rows: usize, | ||
) -> Result<ColumnarValue> { | ||
let answer = match &args[0] { | ||
// When called with static arguments, the result is returned as an array. | ||
ColumnarValue::Scalar(ScalarValue::Utf8(Some(value))) => { | ||
let mut answer = vec![]; | ||
for index in 1..=number_rows { | ||
// When calling a function with immutable arguments, the result is returned with ")". | ||
// Example: SELECT add_index_to_string('const_value') FROM table; | ||
answer.push(index.to_string() + ") " + value); | ||
} | ||
answer | ||
} | ||
// The result is returned as an array when called with dynamic arguments. | ||
ColumnarValue::Array(array) => { | ||
let string_array = as_string_array(array); | ||
let mut counter = HashMap::<&str, u64>::new(); | ||
string_array | ||
.iter() | ||
.map(|value| { | ||
let value = value.expect("Unexpected null"); | ||
let index = counter.get(value).unwrap_or(&0) + 1; | ||
counter.insert(value, index); | ||
|
||
// When calling a function with mutable arguments, the result is returned with ".". | ||
// Example: SELECT add_index_to_string(table.value) FROM table; | ||
index.to_string() + ". " + value | ||
}) | ||
.collect() | ||
} | ||
_ => unimplemented!(), | ||
}; | ||
Ok(ColumnarValue::Array(Arc::new(StringArray::from(answer)))) | ||
} | ||
} | ||
|
||
#[tokio::test] | ||
async fn volatile_scalar_udf_with_params() -> Result<()> { | ||
{ | ||
let schema = Schema::new(vec![Field::new("a", DataType::Utf8, false)]); | ||
|
||
let batch = RecordBatch::try_new( | ||
Arc::new(schema.clone()), | ||
vec![Arc::new(StringArray::from(vec![ | ||
"test_1", "test_1", "test_1", "test_2", "test_2", "test_1", "test_2", | ||
]))], | ||
)?; | ||
let ctx = SessionContext::new(); | ||
|
||
ctx.register_batch("t", batch)?; | ||
|
||
let get_new_str_udf = AddIndexToStringVolatileScalarUDF::new(); | ||
|
||
ctx.register_udf(ScalarUDF::from(get_new_str_udf)); | ||
|
||
let result = | ||
plan_and_collect(&ctx, "select add_index_to_string(t.a) AS str from t") // with dynamic function parameters | ||
.await?; | ||
let expected = [ | ||
"+-----------+", | ||
"| str |", | ||
"+-----------+", | ||
"| 1. test_1 |", | ||
"| 2. test_1 |", | ||
"| 3. test_1 |", | ||
"| 1. test_2 |", | ||
"| 2. test_2 |", | ||
"| 4. test_1 |", | ||
"| 3. test_2 |", | ||
"+-----------+", | ||
]; | ||
assert_batches_eq!(expected, &result); | ||
|
||
let result = | ||
plan_and_collect(&ctx, "select add_index_to_string('test') AS str from t") // with fixed function parameters | ||
.await?; | ||
let expected = [ | ||
"+---------+", | ||
"| str |", | ||
"+---------+", | ||
"| 1) test |", | ||
"| 2) test |", | ||
"| 3) test |", | ||
"| 4) test |", | ||
"| 5) test |", | ||
"| 6) test |", | ||
"| 7) test |", | ||
"+---------+", | ||
]; | ||
assert_batches_eq!(expected, &result); | ||
|
||
let result = | ||
plan_and_collect(&ctx, "select add_index_to_string('test_value') as str") // with fixed function parameters | ||
.await?; | ||
let expected = [ | ||
"+---------------+", | ||
"| str |", | ||
"+---------------+", | ||
"| 1) test_value |", | ||
"+---------------+", | ||
]; | ||
assert_batches_eq!(expected, &result); | ||
} | ||
{ | ||
Comment on lines
+632
to
+633
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nit: consider separating test cases into separate test functions, this would give them descriptive names There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. These tests are very similar, so they don't break down into two parts well. |
||
let schema = Schema::new(vec![Field::new("a", DataType::Utf8, false)]); | ||
|
||
let batch = RecordBatch::try_new( | ||
Arc::new(schema.clone()), | ||
vec![Arc::new(StringArray::from(vec![ | ||
"test_1", "test_1", "test_1", | ||
]))], | ||
)?; | ||
let ctx = SessionContext::new(); | ||
|
||
ctx.register_batch("t", batch)?; | ||
|
||
let get_new_str_udf = AddIndexToStringVolatileScalarUDF::new(); | ||
|
||
ctx.register_udf(ScalarUDF::from(get_new_str_udf)); | ||
|
||
let result = | ||
plan_and_collect(&ctx, "select add_index_to_string(t.a) AS str from t") | ||
.await?; | ||
let expected = [ | ||
"+-----------+", // | ||
"| str |", // | ||
"+-----------+", // | ||
"| 1. test_1 |", // | ||
"| 2. test_1 |", // | ||
"| 3. test_1 |", // | ||
"+-----------+", | ||
]; | ||
assert_batches_eq!(expected, &result); | ||
} | ||
Ok(()) | ||
} | ||
|
||
#[derive(Debug)] | ||
struct CastToI64UDF { | ||
signature: Signature, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -201,6 +201,17 @@ impl ScalarUDF { | |
self.inner.is_nullable(args, schema) | ||
} | ||
|
||
/// Invoke the function with `args` and number of rows, returning the appropriate result. | ||
/// | ||
/// See [`ScalarUDFImpl::invoke_batch`] for more details. | ||
pub fn invoke_batch( | ||
&self, | ||
args: &[ColumnarValue], | ||
number_rows: usize, | ||
) -> Result<ColumnarValue> { | ||
self.inner.invoke_batch(args, number_rows) | ||
} | ||
|
||
/// Invoke the function without `args` but number of rows, returning the appropriate result. | ||
/// | ||
/// See [`ScalarUDFImpl::invoke_no_args`] for more details. | ||
|
@@ -467,7 +478,25 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { | |
/// to arrays, which will likely be simpler code, but be slower. | ||
/// | ||
/// [invoke_no_args]: ScalarUDFImpl::invoke_no_args | ||
fn invoke(&self, _args: &[ColumnarValue]) -> Result<ColumnarValue>; | ||
fn invoke(&self, _args: &[ColumnarValue]) -> Result<ColumnarValue> { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nice! -- as a follow on PR I think we should deprecate the other two functions (`invoke` and `invoke_no_args`). Is this ok with you @jayzhan211 ? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
did you mean `invoke` and `invoke_no_args`? yes, it would be great to have only one invoke entry-point |
||
not_impl_err!( | ||
"Function {} does not implement invoke but called", | ||
self.name() | ||
) | ||
} | ||
|
||
/// Invoke the function with `args` and the number of rows, | ||
/// returning the appropriate result. | ||
fn invoke_batch( | ||
&self, | ||
args: &[ColumnarValue], | ||
number_rows: usize, | ||
) -> Result<ColumnarValue> { | ||
match args.is_empty() { | ||
true => self.invoke_no_args(number_rows), | ||
false => self.invoke(args), | ||
} | ||
} | ||
|
||
/// Invoke the function without `args`, instead the number of rows are provided, | ||
/// returning the appropriate result. | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nice