Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement serialization for UDWF and UDAF in plan protobuf #6769

Merged
merged 3 commits into from
Jun 30, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion datafusion/proto/proto/datafusion.proto
Original file line number Diff line number Diff line change
Expand Up @@ -639,7 +639,8 @@ message WindowExprNode {
oneof window_function {
AggregateFunction aggr_function = 1;
BuiltInWindowFunction built_in_function = 2;
// udaf = 3
string udaf = 3;
string udwf = 9;
}
LogicalExprNode expr = 4;
repeated LogicalExprNode partition_by = 5;
Expand Down
24 changes: 24 additions & 0 deletions datafusion/proto/src/generated/pbjson.rs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 5 additions & 2 deletions datafusion/proto/src/generated/prost.rs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

30 changes: 30 additions & 0 deletions datafusion/proto/src/logical_plan/from_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -997,6 +997,36 @@ pub fn parse_expr(
window_frame,
)))
}
window_expr_node::WindowFunction::Udaf(udaf_name) => {
let udaf_function = registry.udaf(udaf_name)?;
let args = parse_optional_expr(expr.expr.as_deref(), registry)?
.map(|e| vec![e])
.unwrap_or_else(Vec::new);
Ok(Expr::WindowFunction(WindowFunction::new(
datafusion_expr::window_function::WindowFunction::AggregateUDF(
udaf_function,
),
args,
partition_by,
order_by,
window_frame,
)))
}
window_expr_node::WindowFunction::Udwf(udwf_name) => {
let udwf_function = registry.udwf(udwf_name)?;
let args = parse_optional_expr(expr.expr.as_deref(), registry)?
.map(|e| vec![e])
.unwrap_or_else(Vec::new);
Ok(Expr::WindowFunction(WindowFunction::new(
datafusion_expr::window_function::WindowFunction::WindowUDF(
udwf_function,
),
args,
partition_by,
order_by,
window_frame,
)))
}
}
}
ExprType::AggregateExpr(expr) => {
Expand Down
56 changes: 55 additions & 1 deletion datafusion/proto/src/logical_plan/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2786,12 +2786,66 @@ mod roundtrip_tests {
vec![col("col1")],
vec![col("col1")],
vec![col("col2")],
row_number_frame.clone(),
));
#[derive(Debug)]
struct Dummy {}

impl Accumulator for Dummy {
fn state(&self) -> datafusion::error::Result<Vec<ScalarValue>> {
Ok(vec![])
}

fn update_batch(
&mut self,
_values: &[ArrayRef],
) -> datafusion::error::Result<()> {
Ok(())
}

fn merge_batch(
&mut self,
_states: &[ArrayRef],
) -> datafusion::error::Result<()> {
Ok(())
}

fn evaluate(&self) -> datafusion::error::Result<ScalarValue> {
Ok(ScalarValue::Float64(None))
}

fn size(&self) -> usize {
std::mem::size_of_val(self)
}
}

let dummy_agg = create_udaf(
// the name; used to represent it in plan descriptions and in the registry, to use in SQL.
"dummy_agg",
// the input type; DataFusion guarantees that the first entry of `values` in `update` has this type.
DataType::Float64,
// the return type; DataFusion expects this to match the type returned by `evaluate`.
Arc::new(DataType::Float64),
Volatility::Immutable,
// This is the accumulator factory; DataFusion uses it to create new accumulators.
Arc::new(|_| Ok(Box::new(Dummy {}))),
// This is the description of the state. `state()` must match the types here.
Arc::new(vec![DataType::Float64, DataType::UInt32]),
);

let test_expr5 = Expr::WindowFunction(expr::WindowFunction::new(
WindowFunction::AggregateUDF(Arc::new(dummy_agg.clone())),
vec![col("col1")],
vec![col("col1")],
vec![col("col2")],
row_number_frame,
));
ctx.register_udaf(dummy_agg);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

THis is great!

Now all we need is a test for the WindowUDF (aka WindowFunction::WindowUDF) and I think this PR is ready to go!


roundtrip_expr_test(test_expr1, ctx.clone());
roundtrip_expr_test(test_expr2, ctx.clone());
roundtrip_expr_test(test_expr3, ctx.clone());
roundtrip_expr_test(test_expr4, ctx);
roundtrip_expr_test(test_expr4, ctx.clone());
roundtrip_expr_test(test_expr5, ctx);
}
}
16 changes: 8 additions & 8 deletions datafusion/proto/src/logical_plan/to_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -585,16 +585,16 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode {
)
}
// TODO: Tracked in https://github.com/apache/arrow-datafusion/issues/4584
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps we can remove the TODO comments as well

WindowFunction::AggregateUDF(_) => {
return Err(Error::NotImplemented(
"UDAF as window function in proto".to_string(),
))
WindowFunction::AggregateUDF(aggr_udf) => {
protobuf::window_expr_node::WindowFunction::Udaf(
aggr_udf.name.clone(),
)
}
// TODO: Tracked in https://github.com/apache/arrow-datafusion/issues/6733
WindowFunction::WindowUDF(_) => {
return Err(Error::NotImplemented(
"UDWF as window function in proto".to_string(),
))
WindowFunction::WindowUDF(window_udf) => {
protobuf::window_expr_node::WindowFunction::Udwf(
window_udf.name.clone(),
)
}
};
let arg_expr: Option<Box<Self>> = if !args.is_empty() {
Expand Down