-
Notifications
You must be signed in to change notification settings - Fork 1.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement serialization for UDWF and UDAF in plan protobuf #6769
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1431,7 +1431,8 @@ mod roundtrip_tests { | |
logical_plan_to_bytes, logical_plan_to_bytes_with_extension_codec, | ||
}; | ||
use crate::logical_plan::LogicalExtensionCodec; | ||
use arrow::datatypes::{Fields, Schema, SchemaRef, UnionFields}; | ||
use arrow::array::{AsArray, Float64Array}; | ||
use arrow::datatypes::{Fields, Float64Type, Schema, SchemaRef, UnionFields}; | ||
use arrow::{ | ||
array::ArrayRef, | ||
datatypes::{ | ||
|
@@ -1460,7 +1461,8 @@ mod roundtrip_tests { | |
Expr, LogicalPlan, Operator, TryCast, Volatility, | ||
}; | ||
use datafusion_expr::{ | ||
create_udaf, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunction, | ||
create_udaf, PartitionEvaluator, Signature, WindowFrame, WindowFrameBound, | ||
WindowFrameUnits, WindowFunction, WindowUDF, | ||
}; | ||
use prost::Message; | ||
use std::collections::HashMap; | ||
|
@@ -2786,12 +2788,155 @@ mod roundtrip_tests { | |
vec![col("col1")], | ||
vec![col("col1")], | ||
vec![col("col2")], | ||
row_number_frame.clone(), | ||
)); | ||
|
||
// 5. test with AggregateUDF | ||
/// Field-less placeholder type; it receives the `Accumulator` impl used by the `dummy_agg` UDAF in this test. | ||
#[derive(Debug)] | ||
struct Dummy {} | ||
|
||
/// No-op `Accumulator`: every method ignores its input and returns a fixed value. | ||
/// NOTE(review): presumably sufficient because the test only round-trips the expression | ||
/// through serialization and never executes the aggregate -- confirm. | ||
impl Accumulator for Dummy { | ||
/// Returns an empty state vector. | ||
/// NOTE(review): the `create_udaf` call in this test declares two state types | ||
/// (`Float64`, `UInt32`) while this returns none -- confirm the mismatch is harmless here. | ||
fn state(&self) -> datafusion::error::Result<Vec<ScalarValue>> { | ||
Ok(vec![]) | ||
} | ||
|
||
/// Accepts a batch of input values and discards it. | ||
fn update_batch( | ||
&mut self, | ||
_values: &[ArrayRef], | ||
) -> datafusion::error::Result<()> { | ||
Ok(()) | ||
} | ||
|
||
/// Accepts a batch of partial states and discards it. | ||
fn merge_batch( | ||
&mut self, | ||
_states: &[ArrayRef], | ||
) -> datafusion::error::Result<()> { | ||
Ok(()) | ||
} | ||
|
||
/// Always evaluates to a NULL `Float64`. | ||
fn evaluate(&self) -> datafusion::error::Result<ScalarValue> { | ||
Ok(ScalarValue::Float64(None)) | ||
} | ||
|
||
/// Reports the in-memory size of this value via `size_of_val`. | ||
fn size(&self) -> usize { | ||
std::mem::size_of_val(self) | ||
} | ||
} | ||
|
||
let dummy_agg = create_udaf( | ||
// the name; used to represent it in plan descriptions and in the registry, to use in SQL. | ||
"dummy_agg", | ||
// the input type; DataFusion guarantees that the first entry of `values` in `update` has this type. | ||
DataType::Float64, | ||
// the return type; DataFusion expects this to match the type returned by `evaluate`. | ||
Arc::new(DataType::Float64), | ||
Volatility::Immutable, | ||
// This is the accumulator factory; DataFusion uses it to create new accumulators. | ||
Arc::new(|_| Ok(Box::new(Dummy {}))), | ||
// This is the description of the state. `state()` must match the types here. | ||
Arc::new(vec![DataType::Float64, DataType::UInt32]), | ||
); | ||
|
||
let test_expr5 = Expr::WindowFunction(expr::WindowFunction::new( | ||
WindowFunction::AggregateUDF(Arc::new(dummy_agg.clone())), | ||
vec![col("col1")], | ||
vec![col("col1")], | ||
vec![col("col2")], | ||
row_number_frame.clone(), | ||
)); | ||
ctx.register_udaf(dummy_agg); | ||
|
||
// 6. test with WindowUDF | ||
/// Field-less evaluator type; it implements `PartitionEvaluator` for the `smooth_it` window UDF in this test. | ||
#[derive(Clone, Debug)] | ||
struct MyPartitionEvaluator {} | ||
|
||
impl MyPartitionEvaluator { | ||
/// Creates a new evaluator; the type has no fields, so there is nothing to initialize. | ||
fn new() -> Self { | ||
Self {} | ||
} | ||
} | ||
|
||
/// Different evaluation methods are called depending on the various | ||
/// settings of WindowUDF. This example uses the simplest and most | ||
/// general, `evaluate`. See `PartitionEvaluator` for the other more | ||
/// advanced uses. | ||
impl PartitionEvaluator for MyPartitionEvaluator { | ||
/// Tell DataFusion the window function varies based on the value | ||
/// of the window frame. | ||
fn uses_window_frame(&self) -> bool { | ||
true | ||
} | ||
|
||
/// This function is called once per input row. | ||
/// | ||
/// `range` specifies which indexes of `values` should be | ||
/// considered for the calculation. | ||
/// | ||
/// Returns the mean of the selected values as a `Float64`, or a | ||
/// NULL `Float64` when `range` is empty. | ||
/// | ||
/// Note this is the SLOWEST, but simplest, way to evaluate a | ||
/// window function. It is much faster to implement | ||
/// evaluate_all or evaluate_all_with_rank, if possible. | ||
fn evaluate( | ||
&mut self, | ||
values: &[ArrayRef], | ||
range: &std::ops::Range<usize>, | ||
) -> Result<ScalarValue> { | ||
// The first input argument is viewed as an array of floating | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Perhaps we can just use |
||
// point numbers over which the moving average is calculated | ||
// NOTE(review): indexing `values[0]` assumes at least one argument is supplied -- confirm | ||
let arr: &Float64Array = values[0].as_ref().as_primitive::<Float64Type>(); | ||
|
||
let range_len = range.end - range.start; | ||
|
||
// our smoothing function will average all the values in the | ||
// range; an empty range produces a NULL result instead of dividing by zero | ||
// NOTE(review): `values()` presumably reads the raw value buffer, so NULL | ||
// inputs would not be skipped -- confirm against the arrow API | ||
let output = if range_len > 0 { | ||
let sum: f64 = | ||
arr.values().iter().skip(range.start).take(range_len).sum(); | ||
Some(sum / range_len as f64) | ||
} else { | ||
None | ||
}; | ||
|
||
Ok(ScalarValue::Float64(output)) | ||
} | ||
} | ||
|
||
/// Return-type function for the window UDF built in this test. | ||
/// | ||
/// Expects exactly one argument type and returns a clone of it wrapped in an `Arc`; | ||
/// any other argument count yields a `DataFusionError::Plan` error. | ||
/// NOTE(review): the error message names `my_udwf` while the UDF in this test is | ||
/// constructed with the name `smooth_it` -- confirm which name is intended. | ||
fn return_type(arg_types: &[DataType]) -> Result<Arc<DataType>> { | ||
if arg_types.len() != 1 { | ||
return Err(DataFusionError::Plan(format!( | ||
"my_udwf expects 1 argument, got {}: {:?}", | ||
arg_types.len(), | ||
arg_types | ||
))); | ||
} | ||
Ok(Arc::new(arg_types[0].clone())) | ||
} | ||
|
||
/// Factory that returns a freshly boxed `MyPartitionEvaluator`; it never fails. | ||
fn make_partition_evaluator() -> Result<Box<dyn PartitionEvaluator>> { | ||
Ok(Box::new(MyPartitionEvaluator::new())) | ||
} | ||
|
||
let dummy_window_udf = WindowUDF { | ||
name: String::from("smooth_it"), | ||
// it takes 1 argument -- the column to smooth | ||
signature: Signature::exact(vec![DataType::Float64], Volatility::Immutable), | ||
return_type: Arc::new(return_type), | ||
partition_evaluator_factory: Arc::new(make_partition_evaluator), | ||
}; | ||
|
||
let test_expr6 = Expr::WindowFunction(expr::WindowFunction::new( | ||
WindowFunction::WindowUDF(Arc::new(dummy_window_udf.clone())), | ||
vec![col("col1")], | ||
vec![col("col1")], | ||
vec![col("col2")], | ||
row_number_frame, | ||
)); | ||
|
||
ctx.register_udwf(dummy_window_udf); | ||
|
||
roundtrip_expr_test(test_expr1, ctx.clone()); | ||
roundtrip_expr_test(test_expr2, ctx.clone()); | ||
roundtrip_expr_test(test_expr3, ctx.clone()); | ||
roundtrip_expr_test(test_expr4, ctx); | ||
roundtrip_expr_test(test_expr4, ctx.clone()); | ||
roundtrip_expr_test(test_expr5, ctx.clone()); | ||
roundtrip_expr_test(test_expr6, ctx); | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is great!
Now all we need is a test for the `WindowUDF` (aka `WindowFunction::WindowUDF`) and I think this PR is ready to go!