Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement serialization for UDWF and UDAF in plan protobuf #6769

Merged
merged 3 commits into from
Jun 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion datafusion/proto/proto/datafusion.proto
Original file line number Diff line number Diff line change
Expand Up @@ -639,7 +639,8 @@ message WindowExprNode {
oneof window_function {
AggregateFunction aggr_function = 1;
BuiltInWindowFunction built_in_function = 2;
// udaf = 3
string udaf = 3;
string udwf = 9;
}
LogicalExprNode expr = 4;
repeated LogicalExprNode partition_by = 5;
Expand Down
24 changes: 24 additions & 0 deletions datafusion/proto/src/generated/pbjson.rs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 5 additions & 2 deletions datafusion/proto/src/generated/prost.rs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

30 changes: 30 additions & 0 deletions datafusion/proto/src/logical_plan/from_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -997,6 +997,36 @@ pub fn parse_expr(
window_frame,
)))
}
window_expr_node::WindowFunction::Udaf(udaf_name) => {
let udaf_function = registry.udaf(udaf_name)?;
let args = parse_optional_expr(expr.expr.as_deref(), registry)?
.map(|e| vec![e])
.unwrap_or_else(Vec::new);
Ok(Expr::WindowFunction(WindowFunction::new(
datafusion_expr::window_function::WindowFunction::AggregateUDF(
udaf_function,
),
args,
partition_by,
order_by,
window_frame,
)))
}
window_expr_node::WindowFunction::Udwf(udwf_name) => {
let udwf_function = registry.udwf(udwf_name)?;
let args = parse_optional_expr(expr.expr.as_deref(), registry)?
.map(|e| vec![e])
.unwrap_or_else(Vec::new);
Ok(Expr::WindowFunction(WindowFunction::new(
datafusion_expr::window_function::WindowFunction::WindowUDF(
udwf_function,
),
args,
partition_by,
order_by,
window_frame,
)))
}
}
}
ExprType::AggregateExpr(expr) => {
Expand Down
112 changes: 110 additions & 2 deletions datafusion/proto/src/logical_plan/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1460,7 +1460,8 @@ mod roundtrip_tests {
Expr, LogicalPlan, Operator, TryCast, Volatility,
};
use datafusion_expr::{
create_udaf, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunction,
create_udaf, PartitionEvaluator, Signature, WindowFrame, WindowFrameBound,
WindowFrameUnits, WindowFunction, WindowUDF,
};
use prost::Message;
use std::collections::HashMap;
Expand Down Expand Up @@ -2786,12 +2787,119 @@ mod roundtrip_tests {
vec![col("col1")],
vec![col("col1")],
vec![col("col2")],
row_number_frame.clone(),
));

// 5. test with AggregateUDF
#[derive(Debug)]
struct DummyAggr {}

impl Accumulator for DummyAggr {
fn state(&self) -> datafusion::error::Result<Vec<ScalarValue>> {
Ok(vec![])
}

fn update_batch(
&mut self,
_values: &[ArrayRef],
) -> datafusion::error::Result<()> {
Ok(())
}

fn merge_batch(
&mut self,
_states: &[ArrayRef],
) -> datafusion::error::Result<()> {
Ok(())
}

fn evaluate(&self) -> datafusion::error::Result<ScalarValue> {
Ok(ScalarValue::Float64(None))
}

fn size(&self) -> usize {
std::mem::size_of_val(self)
}
}

let dummy_agg = create_udaf(
// the name; used to represent it in plan descriptions and in the registry, to use in SQL.
"dummy_agg",
// the input type; DataFusion guarantees that the first entry of `values` in `update` has this type.
DataType::Float64,
// the return type; DataFusion expects this to match the type returned by `evaluate`.
Arc::new(DataType::Float64),
Volatility::Immutable,
// This is the accumulator factory; DataFusion uses it to create new accumulators.
Arc::new(|_| Ok(Box::new(DummyAggr {}))),
// This is the description of the state. `state()` must match the types here.
Arc::new(vec![DataType::Float64, DataType::UInt32]),
);

let test_expr5 = Expr::WindowFunction(expr::WindowFunction::new(
WindowFunction::AggregateUDF(Arc::new(dummy_agg.clone())),
vec![col("col1")],
vec![col("col1")],
vec![col("col2")],
row_number_frame.clone(),
));
ctx.register_udaf(dummy_agg);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is great!

Now all we need is a test for the WindowUDF (aka WindowFunction::WindowUDF) and I think this PR is ready to go!


// 6. test with WindowUDF
/// Placeholder partition evaluator for the window-UDF round-trip test.
///
/// It declares that it uses the window frame and evaluates every call to a
/// null `Float64`, ignoring its inputs.
#[derive(Clone, Debug)]
struct DummyWindow {}

impl PartitionEvaluator for DummyWindow {
    /// This evaluator reports that it uses the window frame.
    fn uses_window_frame(&self) -> bool {
        true
    }

    /// Ignores both the argument arrays and the range; always returns a
    /// null `Float64`.
    fn evaluate(
        &mut self,
        _values: &[ArrayRef],
        _range: &std::ops::Range<usize>,
    ) -> Result<ScalarValue> {
        let null_value = ScalarValue::Float64(None);
        Ok(null_value)
    }
}

/// Return-type resolver for `dummy_udwf`: the output type is the type of
/// its single argument.
///
/// Returns a `DataFusionError::Plan` error when the number of argument
/// types is anything other than one.
fn return_type(arg_types: &[DataType]) -> Result<Arc<DataType>> {
    match arg_types {
        // Exactly one argument: echo its type back.
        [only] => Ok(Arc::new(only.clone())),
        _ => Err(DataFusionError::Plan(format!(
            "dummy_udwf expects 1 argument, got {}: {:?}",
            arg_types.len(),
            arg_types
        ))),
    }
}

/// Factory for the window UDF: builds a fresh boxed `DummyWindow` on each call.
fn make_partition_evaluator() -> Result<Box<dyn PartitionEvaluator>> {
    let evaluator: Box<dyn PartitionEvaluator> = Box::new(DummyWindow {});
    Ok(evaluator)
}

let dummy_window_udf = WindowUDF {
name: String::from("dummy_udwf"),
signature: Signature::exact(vec![DataType::Float64], Volatility::Immutable),
return_type: Arc::new(return_type),
partition_evaluator_factory: Arc::new(make_partition_evaluator),
};

let test_expr6 = Expr::WindowFunction(expr::WindowFunction::new(
WindowFunction::WindowUDF(Arc::new(dummy_window_udf.clone())),
vec![col("col1")],
vec![col("col1")],
vec![col("col2")],
row_number_frame,
));

ctx.register_udwf(dummy_window_udf);

roundtrip_expr_test(test_expr1, ctx.clone());
roundtrip_expr_test(test_expr2, ctx.clone());
roundtrip_expr_test(test_expr3, ctx.clone());
roundtrip_expr_test(test_expr4, ctx);
roundtrip_expr_test(test_expr4, ctx.clone());
roundtrip_expr_test(test_expr5, ctx.clone());
roundtrip_expr_test(test_expr6, ctx);
}
}
18 changes: 8 additions & 10 deletions datafusion/proto/src/logical_plan/to_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -584,17 +584,15 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode {
protobuf::BuiltInWindowFunction::from(fun).into(),
)
}
// TODO: Tracked in https://github.com/apache/arrow-datafusion/issues/4584
WindowFunction::AggregateUDF(_) => {
return Err(Error::NotImplemented(
"UDAF as window function in proto".to_string(),
))
WindowFunction::AggregateUDF(aggr_udf) => {
protobuf::window_expr_node::WindowFunction::Udaf(
aggr_udf.name.clone(),
)
}
// TODO: Tracked in https://github.com/apache/arrow-datafusion/issues/6733
WindowFunction::WindowUDF(_) => {
return Err(Error::NotImplemented(
"UDWF as window function in proto".to_string(),
))
WindowFunction::WindowUDF(window_udf) => {
protobuf::window_expr_node::WindowFunction::Udwf(
window_udf.name.clone(),
)
}
};
let arg_expr: Option<Box<Self>> = if !args.is_empty() {
Expand Down