Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement serialization for UDWF and UDAF in plan protobuf #6769

Merged
merged 3 commits into from
Jun 30, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion datafusion/proto/proto/datafusion.proto
Original file line number Diff line number Diff line change
Expand Up @@ -639,7 +639,8 @@ message WindowExprNode {
oneof window_function {
AggregateFunction aggr_function = 1;
BuiltInWindowFunction built_in_function = 2;
// udaf = 3
string udaf = 3;
string udwf = 9;
}
LogicalExprNode expr = 4;
repeated LogicalExprNode partition_by = 5;
Expand Down
24 changes: 24 additions & 0 deletions datafusion/proto/src/generated/pbjson.rs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 5 additions & 2 deletions datafusion/proto/src/generated/prost.rs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

30 changes: 30 additions & 0 deletions datafusion/proto/src/logical_plan/from_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -997,6 +997,36 @@ pub fn parse_expr(
window_frame,
)))
}
window_expr_node::WindowFunction::Udaf(udaf_name) => {
let udaf_function = registry.udaf(udaf_name)?;
let args = parse_optional_expr(expr.expr.as_deref(), registry)?
.map(|e| vec![e])
.unwrap_or_else(Vec::new);
Ok(Expr::WindowFunction(WindowFunction::new(
datafusion_expr::window_function::WindowFunction::AggregateUDF(
udaf_function,
),
args,
partition_by,
order_by,
window_frame,
)))
}
window_expr_node::WindowFunction::Udwf(udwf_name) => {
let udwf_function = registry.udwf(udwf_name)?;
let args = parse_optional_expr(expr.expr.as_deref(), registry)?
.map(|e| vec![e])
.unwrap_or_else(Vec::new);
Ok(Expr::WindowFunction(WindowFunction::new(
datafusion_expr::window_function::WindowFunction::WindowUDF(
udwf_function,
),
args,
partition_by,
order_by,
window_frame,
)))
}
}
}
ExprType::AggregateExpr(expr) => {
Expand Down
151 changes: 148 additions & 3 deletions datafusion/proto/src/logical_plan/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1431,7 +1431,8 @@ mod roundtrip_tests {
logical_plan_to_bytes, logical_plan_to_bytes_with_extension_codec,
};
use crate::logical_plan::LogicalExtensionCodec;
use arrow::datatypes::{Fields, Schema, SchemaRef, UnionFields};
use arrow::array::{AsArray, Float64Array};
use arrow::datatypes::{Fields, Float64Type, Schema, SchemaRef, UnionFields};
use arrow::{
array::ArrayRef,
datatypes::{
Expand Down Expand Up @@ -1460,7 +1461,8 @@ mod roundtrip_tests {
Expr, LogicalPlan, Operator, TryCast, Volatility,
};
use datafusion_expr::{
create_udaf, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunction,
create_udaf, PartitionEvaluator, Signature, WindowFrame, WindowFrameBound,
WindowFrameUnits, WindowFunction, WindowUDF,
};
use prost::Message;
use std::collections::HashMap;
Expand Down Expand Up @@ -2786,12 +2788,155 @@ mod roundtrip_tests {
vec![col("col1")],
vec![col("col1")],
vec![col("col2")],
row_number_frame.clone(),
));

// 5. test with AggregateUDF
#[derive(Debug)]
struct Dummy {}

impl Accumulator for Dummy {
fn state(&self) -> datafusion::error::Result<Vec<ScalarValue>> {
Ok(vec![])
}

fn update_batch(
&mut self,
_values: &[ArrayRef],
) -> datafusion::error::Result<()> {
Ok(())
}

fn merge_batch(
&mut self,
_states: &[ArrayRef],
) -> datafusion::error::Result<()> {
Ok(())
}

fn evaluate(&self) -> datafusion::error::Result<ScalarValue> {
Ok(ScalarValue::Float64(None))
}

fn size(&self) -> usize {
std::mem::size_of_val(self)
}
}

let dummy_agg = create_udaf(
// the name; used to represent it in plan descriptions and in the registry, to use in SQL.
"dummy_agg",
// the input type; DataFusion guarantees that the first entry of `values` in `update` has this type.
DataType::Float64,
// the return type; DataFusion expects this to match the type returned by `evaluate`.
Arc::new(DataType::Float64),
Volatility::Immutable,
// This is the accumulator factory; DataFusion uses it to create new accumulators.
Arc::new(|_| Ok(Box::new(Dummy {}))),
// This is the description of the state. `state()` must match the types here.
Arc::new(vec![DataType::Float64, DataType::UInt32]),
);

let test_expr5 = Expr::WindowFunction(expr::WindowFunction::new(
WindowFunction::AggregateUDF(Arc::new(dummy_agg.clone())),
vec![col("col1")],
vec![col("col1")],
vec![col("col2")],
row_number_frame.clone(),
));
ctx.register_udaf(dummy_agg);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is great!

Now all we need is a test for the WindowUDF (aka WindowFunction::WindowUDF) and I think this PR is ready to go!


// 6. test with WindowUDF
/// Field-less evaluator type used as the partition evaluator of the window
/// UDF constructed in this test.
#[derive(Clone, Debug)]
struct MyPartitionEvaluator {}

impl MyPartitionEvaluator {
    /// Builds an evaluator; the type carries no state.
    fn new() -> Self {
        MyPartitionEvaluator {}
    }
}

/// Different evaluation methods are called depending on the various
/// settings of WindowUDF. This example uses the simplest and most
/// general, `evaluate`. See `PartitionEvaluator` for the other more
/// advanced uses.
impl PartitionEvaluator for MyPartitionEvaluator {
/// Tell DataFusion the window function varies based on the value
/// of the window frame.
fn uses_window_frame(&self) -> bool {
true
}

/// This function is called once per input row.
///
/// `range` specifies which indexes of `values` should be
/// considered for the calculation.
///
/// Note this is the SLOWEST, but simplest, way to evaluate a
/// window function. It is much faster to implement
/// evaluate_all or evaluate_all_with_rank, if possible
fn evaluate(
&mut self,
values: &[ArrayRef],
range: &std::ops::Range<usize>,
) -> Result<ScalarValue> {
// Again, the input argument is an array of floating
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps we can just use todo!() here and in the other functions? These should not be called during the logical plan testing

// point numbers to calculate a moving average
let arr: &Float64Array = values[0].as_ref().as_primitive::<Float64Type>();

let range_len = range.end - range.start;

// our smoothing function will average all the values in the given range
let output = if range_len > 0 {
let sum: f64 =
arr.values().iter().skip(range.start).take(range_len).sum();
Some(sum / range_len as f64)
} else {
None
};

Ok(ScalarValue::Float64(output))
}
}

/// Computes the window UDF's return type from its argument types.
///
/// When exactly one argument type is supplied, that same type is returned
/// (cloned and wrapped in an `Arc`). Any other argument count produces a
/// `DataFusionError::Plan` describing what was received.
fn return_type(arg_types: &[DataType]) -> Result<Arc<DataType>> {
    match arg_types {
        [only] => Ok(Arc::new(only.clone())),
        _ => Err(DataFusionError::Plan(format!(
            "my_udwf expects 1 argument, got {}: {:?}",
            arg_types.len(),
            arg_types
        ))),
    }
}

/// Factory passed as the UDF's `partition_evaluator_factory`: returns a fresh,
/// boxed `MyPartitionEvaluator` on every call.
fn make_partition_evaluator() -> Result<Box<dyn PartitionEvaluator>> {
    let evaluator: Box<dyn PartitionEvaluator> = Box::new(MyPartitionEvaluator::new());
    Ok(evaluator)
}

let dummy_window_udf = WindowUDF {
name: String::from("smooth_it"),
// it takes 1 argument -- the column to smooth
signature: Signature::exact(vec![DataType::Float64], Volatility::Immutable),
return_type: Arc::new(return_type),
partition_evaluator_factory: Arc::new(make_partition_evaluator),
};

let test_expr6 = Expr::WindowFunction(expr::WindowFunction::new(
WindowFunction::WindowUDF(Arc::new(dummy_window_udf.clone())),
vec![col("col1")],
vec![col("col1")],
vec![col("col2")],
row_number_frame,
));

ctx.register_udwf(dummy_window_udf);

roundtrip_expr_test(test_expr1, ctx.clone());
roundtrip_expr_test(test_expr2, ctx.clone());
roundtrip_expr_test(test_expr3, ctx.clone());
roundtrip_expr_test(test_expr4, ctx);
roundtrip_expr_test(test_expr4, ctx.clone());
roundtrip_expr_test(test_expr5, ctx.clone());
roundtrip_expr_test(test_expr6, ctx);
}
}
18 changes: 8 additions & 10 deletions datafusion/proto/src/logical_plan/to_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -584,17 +584,15 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode {
protobuf::BuiltInWindowFunction::from(fun).into(),
)
}
// TODO: Tracked in https://github.com/apache/arrow-datafusion/issues/4584
WindowFunction::AggregateUDF(_) => {
return Err(Error::NotImplemented(
"UDAF as window function in proto".to_string(),
))
WindowFunction::AggregateUDF(aggr_udf) => {
protobuf::window_expr_node::WindowFunction::Udaf(
aggr_udf.name.clone(),
)
}
// TODO: Tracked in https://github.com/apache/arrow-datafusion/issues/6733
WindowFunction::WindowUDF(_) => {
return Err(Error::NotImplemented(
"UDWF as window function in proto".to_string(),
))
WindowFunction::WindowUDF(window_udf) => {
protobuf::window_expr_node::WindowFunction::Udwf(
window_udf.name.clone(),
)
}
};
let arg_expr: Option<Box<Self>> = if !args.is_empty() {
Expand Down