From 5b95fa86e1fe36dd0cd3797436e1664626e6a9bc Mon Sep 17 00:00:00 2001 From: liukun4515 Date: Thu, 7 Dec 2023 20:07:43 +0800 Subject: [PATCH] support inter leave node --- datafusion/proto/proto/datafusion.proto | 5 + datafusion/proto/src/generated/pbjson.rs | 104 ++++++++++++++++++ datafusion/proto/src/generated/prost.rs | 10 +- datafusion/proto/src/physical_plan/mod.rs | 44 ++++++-- .../tests/cases/roundtrip_physical_plan.rs | 69 +++++++++--- 5 files changed, 206 insertions(+), 26 deletions(-) diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 863e3c315c820..65dd23739fde9 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -1161,6 +1161,7 @@ message PhysicalPlanNode { AnalyzeExecNode analyze = 23; JsonSinkExecNode json_sink = 24; SymmetricHashJoinExecNode symmetric_hash_join = 25; + InterleaveExecNode interleave = 26; } } @@ -1454,6 +1455,10 @@ message SymmetricHashJoinExecNode { JoinFilter filter = 8; } +message InterleaveExecNode { + repeated PhysicalPlanNode inputs = 1; +} + message UnionExecNode { repeated PhysicalPlanNode inputs = 1; } diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 74798ee8e94c4..47215118de63b 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -9205,6 +9205,97 @@ impl<'de> serde::Deserialize<'de> for InListNode { deserializer.deserialize_struct("datafusion.InListNode", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for InterleaveExecNode { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if !self.inputs.is_empty() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.InterleaveExecNode", len)?; + if !self.inputs.is_empty() { + struct_ser.serialize_field("inputs", &self.inputs)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for InterleaveExecNode { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "inputs", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Inputs, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "inputs" => Ok(GeneratedField::Inputs), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = InterleaveExecNode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.InterleaveExecNode") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut inputs__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Inputs => { + if inputs__.is_some() { + return Err(serde::de::Error::duplicate_field("inputs")); + } + inputs__ = Some(map_.next_value()?); + } + } + } + Ok(InterleaveExecNode { + inputs: inputs__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion.InterleaveExecNode", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for IntervalMonthDayNanoValue { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -17887,6 +17978,9 @@ impl serde::Serialize for PhysicalPlanNode { physical_plan_node::PhysicalPlanType::SymmetricHashJoin(v) => { struct_ser.serialize_field("symmetricHashJoin", v)?; } + physical_plan_node::PhysicalPlanType::Interleave(v) => { + struct_ser.serialize_field("interleave", v)?; + } } } struct_ser.end() @@ -17935,6 +18029,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { "jsonSink", "symmetric_hash_join", "symmetricHashJoin", + "interleave", ]; #[allow(clippy::enum_variant_names)] @@ -17963,6 +18058,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { Analyze, JsonSink, SymmetricHashJoin, + Interleave, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -18008,6 +18104,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { "analyze" => Ok(GeneratedField::Analyze), "jsonSink" | "json_sink" => Ok(GeneratedField::JsonSink), "symmetricHashJoin" | "symmetric_hash_join" => Ok(GeneratedField::SymmetricHashJoin), + "interleave" => Ok(GeneratedField::Interleave), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -18196,6 +18293,13 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { return Err(serde::de::Error::duplicate_field("symmetricHashJoin")); } physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::SymmetricHashJoin) +; + } + GeneratedField::Interleave => { + if physical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("interleave")); + } + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::Interleave) ; } } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index ae20913e3dd73..2f9d8a70321b8 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -1521,7 +1521,7 @@ pub mod owned_table_reference { pub struct PhysicalPlanNode { #[prost( oneof = "physical_plan_node::PhysicalPlanType", - tags = "1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25" + tags = "1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26" )] pub physical_plan_type: ::core::option::Option, } @@ -1580,6 +1580,8 @@ pub mod physical_plan_node { JsonSink(::prost::alloc::boxed::Box), #[prost(message, tag = "25")] SymmetricHashJoin(::prost::alloc::boxed::Box), + #[prost(message, tag = "26")] + Interleave(super::InterleaveExecNode), } } #[allow(clippy::derive_partial_eq_without_eq)] @@ -2038,6 +2040,12 @@ pub struct SymmetricHashJoinExecNode { } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] +pub struct InterleaveExecNode { + #[prost(message, repeated, tag = "1")] + pub inputs: ::prost::alloc::vec::Vec, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] pub struct UnionExecNode { #[prost(message, repeated, tag = "1")] pub inputs: ::prost::alloc::vec::Vec, diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index 74c8ec894ff2a..878a5bcb7f694 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -48,7 +48,7 @@ use datafusion::physical_plan::projection::ProjectionExec; use datafusion::physical_plan::repartition::RepartitionExec; use datafusion::physical_plan::sorts::sort::SortExec; use datafusion::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; -use datafusion::physical_plan::union::UnionExec; +use datafusion::physical_plan::union::{InterleaveExec, UnionExec}; use datafusion::physical_plan::windows::{BoundedWindowAggExec, WindowAggExec}; use datafusion::physical_plan::{ udaf, AggregateExpr, ExecutionPlan, InputOrderMode, Partitioning, PhysicalExpr, @@ -545,7 +545,7 @@ impl AsExecutionPlan for PhysicalPlanNode { f.expression.as_ref().ok_or_else(|| { proto_error("Unexpected empty filter expression") })?, - registry, &schema + registry, &schema, )?; let column_indices = f.column_indices .iter() @@ -556,7 +556,7 @@ impl AsExecutionPlan for PhysicalPlanNode { i.side)) )?; - Ok(ColumnIndex{ + Ok(ColumnIndex { index: i.index as usize, side: side.into(), }) @@ -634,7 +634,7 @@ impl AsExecutionPlan for PhysicalPlanNode { f.expression.as_ref().ok_or_else(|| { proto_error("Unexpected empty filter expression") })?, - registry, &schema + registry, &schema, )?; let column_indices = f.column_indices .iter() @@ -645,7 +645,7 @@ impl AsExecutionPlan for PhysicalPlanNode { i.side)) )?; - Ok(ColumnIndex{ + Ok(ColumnIndex { index: i.index as usize, side: side.into(), }) @@ -693,6 +693,17 @@ impl AsExecutionPlan for PhysicalPlanNode { } Ok(Arc::new(UnionExec::new(inputs))) } + PhysicalPlanType::Interleave(interleave) => { + let mut inputs: Vec> = vec![]; + for input in &interleave.inputs { + inputs.push(input.try_into_physical_plan( + registry, + runtime, + extension_codec, + )?); + } + Ok(Arc::new(InterleaveExec::try_new(inputs)?)) + } PhysicalPlanType::CrossJoin(crossjoin) => { let left: Arc = into_physical_plan( &crossjoin.left, @@ -735,7 +746,7 @@ impl AsExecutionPlan for PhysicalPlanNode { })? .as_ref(); Ok(PhysicalSortExpr { - expr: parse_physical_expr(expr,registry, input.schema().as_ref())?, + expr: parse_physical_expr(expr, registry, input.schema().as_ref())?, options: SortOptions { descending: !sort_expr.asc, nulls_first: sort_expr.nulls_first, @@ -782,7 +793,7 @@ impl AsExecutionPlan for PhysicalPlanNode { })? .as_ref(); Ok(PhysicalSortExpr { - expr: parse_physical_expr(expr,registry, input.schema().as_ref())?, + expr: parse_physical_expr(expr, registry, input.schema().as_ref())?, options: SortOptions { descending: !sort_expr.asc, nulls_first: sort_expr.nulls_first, @@ -845,7 +856,7 @@ impl AsExecutionPlan for PhysicalPlanNode { f.expression.as_ref().ok_or_else(|| { proto_error("Unexpected empty filter expression") })?, - registry, &schema + registry, &schema, )?; let column_indices = f.column_indices .iter() @@ -856,7 +867,7 @@ impl AsExecutionPlan for PhysicalPlanNode { i.side)) )?; - Ok(ColumnIndex{ + Ok(ColumnIndex { index: i.index as usize, side: side.into(), }) @@ -1463,6 +1474,21 @@ impl AsExecutionPlan for PhysicalPlanNode { }); } + if let Some(interleave) = plan.downcast_ref::() { + let mut inputs: Vec = vec![]; + for input in interleave.inputs() { + inputs.push(protobuf::PhysicalPlanNode::try_from_physical_plan( + input.to_owned(), + extension_codec, + )?); + } + return Ok(protobuf::PhysicalPlanNode { + physical_plan_type: Some(PhysicalPlanType::Interleave( + protobuf::InterleaveExecNode { inputs }, + )), + }); + } + if let Some(exec) = plan.downcast_ref::() { let input = protobuf::PhysicalPlanNode::try_from_physical_plan( exec.input().to_owned(), diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index 287207bae5f66..2cc060dd759c6 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -50,12 +50,14 @@ use datafusion::physical_plan::joins::{ }; use datafusion::physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; use datafusion::physical_plan::projection::ProjectionExec; +use datafusion::physical_plan::repartition::RepartitionExec; use datafusion::physical_plan::sorts::sort::SortExec; +use datafusion::physical_plan::union::{InterleaveExec, UnionExec}; use datafusion::physical_plan::windows::{ BuiltInWindowExpr, PlainAggregateWindowExpr, WindowAggExec, }; use datafusion::physical_plan::{ - functions, udaf, AggregateExpr, ExecutionPlan, PhysicalExpr, Statistics, + functions, udaf, AggregateExpr, ExecutionPlan, Partitioning, PhysicalExpr, Statistics, }; use datafusion::prelude::SessionContext; use datafusion::scalar::ScalarValue; @@ -231,21 +233,21 @@ fn roundtrip_window() -> Result<()> { }; let builtin_window_expr = Arc::new(BuiltInWindowExpr::new( - Arc::new(NthValue::first( - "FIRST_VALUE(a) PARTITION BY [b] ORDER BY [a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", - col("a", &schema)?, - DataType::Int64, - )), - &[col("b", &schema)?], - &[PhysicalSortExpr { - expr: col("a", &schema)?, - options: SortOptions { - descending: false, - nulls_first: false, - }, - }], - Arc::new(window_frame), - )); + Arc::new(NthValue::first( + "FIRST_VALUE(a) PARTITION BY [b] ORDER BY [a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", + col("a", &schema)?, + DataType::Int64, + )), + &[col("b", &schema)?], + &[PhysicalSortExpr { + expr: col("a", &schema)?, + options: SortOptions { + descending: false, + nulls_first: false, + }, + }], + Arc::new(window_frame), + )); let plain_aggr_window_expr = Arc::new(PlainAggregateWindowExpr::new( Arc::new(Avg::new( @@ -798,3 +800,38 @@ fn roundtrip_sym_hash_join() -> Result<()> { } Ok(()) } + +#[test] +fn roundtrip_union() -> Result<()> { + let field_a = Field::new("col", DataType::Int64, false); + let schema_left = Schema::new(vec![field_a.clone()]); + let schema_right = Schema::new(vec![field_a]); + let left = EmptyExec::new(false, Arc::new(schema_left)); + let right = EmptyExec::new(false, Arc::new(schema_right)); + let mut inputs: Vec> = vec![]; + inputs.push(Arc::new(left)); + inputs.push(Arc::new(right)); + let union = UnionExec::new(inputs); + roundtrip_test(Arc::new(union)) +} + +#[test] +fn roundtrip_interleave() -> Result<()> { + let field_a = Field::new("col", DataType::Int64, false); + let schema_left = Schema::new(vec![field_a.clone()]); + let schema_right = Schema::new(vec![field_a]); + let partition = Partitioning::Hash(vec![], 3); + let left = RepartitionExec::try_new( + Arc::new(EmptyExec::new(false, Arc::new(schema_left))), + partition.clone(), + )?; + let right = RepartitionExec::try_new( + Arc::new(EmptyExec::new(false, Arc::new(schema_right))), + partition.clone(), + )?; + let mut inputs: Vec> = vec![]; + inputs.push(Arc::new(left)); + inputs.push(Arc::new(right)); + let interleave = InterleaveExec::try_new(inputs)?; + roundtrip_test(Arc::new(interleave)) +}