diff --git a/datafusion/core/src/physical_plan/joins/nested_loop_join.rs b/datafusion/core/src/physical_plan/joins/nested_loop_join.rs index d5de88d9339e..7f9c2fa2c06f 100644 --- a/datafusion/core/src/physical_plan/joins/nested_loop_join.rs +++ b/datafusion/core/src/physical_plan/joins/nested_loop_join.rs @@ -118,6 +118,26 @@ impl NestedLoopJoinExec { metrics: Default::default(), }) } + + /// left (build) side which gets hashed + pub fn left(&self) -> &Arc { + &self.left + } + + /// right (probe) side which are filtered by the hash table + pub fn right(&self) -> &Arc { + &self.right + } + + /// Filters applied before join output + pub fn filter(&self) -> Option<&JoinFilter> { + self.filter.as_ref() + } + + /// How the join is performed + pub fn join_type(&self) -> &JoinType { + &self.join_type + } } impl DisplayAs for NestedLoopJoinExec { diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 528c67557093..89bca57cf306 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -1059,6 +1059,7 @@ message PhysicalPlanNode { UnionExecNode union = 19; ExplainExecNode explain = 20; SortPreservingMergeExecNode sort_preserving_merge = 21; + NestedLoopJoinExecNode nested_loop_join = 22; } } @@ -1380,6 +1381,13 @@ message SortPreservingMergeExecNode { int64 fetch = 3; } +message NestedLoopJoinExecNode { + PhysicalPlanNode left = 1; + PhysicalPlanNode right = 2; + JoinType join_type = 3; + JoinFilter filter = 4; +} + message CoalesceBatchesExecNode { PhysicalPlanNode input = 1; uint32 target_batch_size = 2; diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index d6a770159b04..590b462ad815 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -12010,6 +12010,151 @@ impl<'de> serde::Deserialize<'de> for NegativeNode { deserializer.deserialize_struct("datafusion.NegativeNode", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for NestedLoopJoinExecNode { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.left.is_some() { + len += 1; + } + if self.right.is_some() { + len += 1; + } + if self.join_type != 0 { + len += 1; + } + if self.filter.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.NestedLoopJoinExecNode", len)?; + if let Some(v) = self.left.as_ref() { + struct_ser.serialize_field("left", v)?; + } + if let Some(v) = self.right.as_ref() { + struct_ser.serialize_field("right", v)?; + } + if self.join_type != 0 { + let v = JoinType::from_i32(self.join_type) + .ok_or_else(|| serde::ser::Error::custom(format!("Invalid variant {}", self.join_type)))?; + struct_ser.serialize_field("joinType", &v)?; + } + if let Some(v) = self.filter.as_ref() { + struct_ser.serialize_field("filter", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for NestedLoopJoinExecNode { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "left", + "right", + "join_type", + "joinType", + "filter", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Left, + Right, + JoinType, + Filter, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "left" => Ok(GeneratedField::Left), + "right" => Ok(GeneratedField::Right), + "joinType" | "join_type" => Ok(GeneratedField::JoinType), + "filter" => Ok(GeneratedField::Filter), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = NestedLoopJoinExecNode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.NestedLoopJoinExecNode") + } + + fn visit_map(self, mut map: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut left__ = None; + let mut right__ = None; + let mut join_type__ = None; + let mut filter__ = None; + while let Some(k) = map.next_key()? { + match k { + GeneratedField::Left => { + if left__.is_some() { + return Err(serde::de::Error::duplicate_field("left")); + } + left__ = map.next_value()?; + } + GeneratedField::Right => { + if right__.is_some() { + return Err(serde::de::Error::duplicate_field("right")); + } + right__ = map.next_value()?; + } + GeneratedField::JoinType => { + if join_type__.is_some() { + return Err(serde::de::Error::duplicate_field("joinType")); + } + join_type__ = Some(map.next_value::()? as i32); + } + GeneratedField::Filter => { + if filter__.is_some() { + return Err(serde::de::Error::duplicate_field("filter")); + } + filter__ = map.next_value()?; + } + } + } + Ok(NestedLoopJoinExecNode { + left: left__, + right: right__, + join_type: join_type__.unwrap_or_default(), + filter: filter__, + }) + } + } + deserializer.deserialize_struct("datafusion.NestedLoopJoinExecNode", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for Not { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -15335,6 +15480,9 @@ impl serde::Serialize for PhysicalPlanNode { physical_plan_node::PhysicalPlanType::SortPreservingMerge(v) => { struct_ser.serialize_field("sortPreservingMerge", v)?; } + physical_plan_node::PhysicalPlanType::NestedLoopJoin(v) => { + struct_ser.serialize_field("nestedLoopJoin", v)?; + } } } struct_ser.end() @@ -15376,6 +15524,8 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { "explain", "sort_preserving_merge", "sortPreservingMerge", + "nested_loop_join", + "nestedLoopJoin", ]; #[allow(clippy::enum_variant_names)] @@ -15400,6 +15550,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { Union, Explain, SortPreservingMerge, + NestedLoopJoin, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -15441,6 +15592,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { "union" => Ok(GeneratedField::Union), "explain" => Ok(GeneratedField::Explain), "sortPreservingMerge" | "sort_preserving_merge" => Ok(GeneratedField::SortPreservingMerge), + "nestedLoopJoin" | "nested_loop_join" => Ok(GeneratedField::NestedLoopJoin), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -15601,6 +15753,13 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { return Err(serde::de::Error::duplicate_field("sortPreservingMerge")); } physical_plan_type__ = map.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::SortPreservingMerge) +; + } + GeneratedField::NestedLoopJoin => { + if physical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("nestedLoopJoin")); + } + physical_plan_type__ = map.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::NestedLoopJoin) ; } } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 4e91fbab19fb..251760f0902a 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -1394,7 +1394,7 @@ pub mod owned_table_reference { pub struct PhysicalPlanNode { #[prost( oneof = "physical_plan_node::PhysicalPlanType", - tags = "1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21" + tags = "1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22" )] pub physical_plan_type: ::core::option::Option, } @@ -1445,6 +1445,8 @@ pub mod physical_plan_node { SortPreservingMerge( ::prost::alloc::boxed::Box, ), + #[prost(message, tag = "22")] + NestedLoopJoin(::prost::alloc::boxed::Box), } } #[allow(clippy::derive_partial_eq_without_eq)] @@ -1950,6 +1952,18 @@ pub struct SortPreservingMergeExecNode { } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] +pub struct NestedLoopJoinExecNode { + #[prost(message, optional, boxed, tag = "1")] + pub left: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(message, optional, boxed, tag = "2")] + pub right: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(enumeration = "JoinType", tag = "3")] + pub join_type: i32, + #[prost(message, optional, tag = "4")] + pub filter: ::core::option::Option, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] pub struct CoalesceBatchesExecNode { #[prost(message, optional, boxed, tag = "1")] pub input: ::core::option::Option<::prost::alloc::boxed::Box>, diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index 7bbbe135680b..b5b4aeb2dabc 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -35,7 +35,7 @@ use datafusion::physical_plan::explain::ExplainExec; use datafusion::physical_plan::expressions::{Column, PhysicalSortExpr}; use datafusion::physical_plan::filter::FilterExec; use datafusion::physical_plan::joins::utils::{ColumnIndex, JoinFilter}; -use datafusion::physical_plan::joins::CrossJoinExec; +use datafusion::physical_plan::joins::{CrossJoinExec, NestedLoopJoinExec}; use datafusion::physical_plan::joins::{HashJoinExec, PartitionMode}; use datafusion::physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; use datafusion::physical_plan::projection::ProjectionExec; @@ -716,6 +716,61 @@ impl AsExecutionPlan for PhysicalPlanNode { Ok(extension_node) } + PhysicalPlanType::NestedLoopJoin(join) => { + let left: Arc = + into_physical_plan(&join.left, registry, runtime, extension_codec)?; + let right: Arc = + into_physical_plan(&join.right, registry, runtime, extension_codec)?; + let join_type = + protobuf::JoinType::from_i32(join.join_type).ok_or_else(|| { + proto_error(format!( + "Received a NestedLoopJoinExecNode message with unknown JoinType {}", + join.join_type + )) + })?; + let filter = join + .filter + .as_ref() + .map(|f| { + let schema = f + .schema + .as_ref() + .ok_or_else(|| proto_error("Missing JoinFilter schema"))? + .try_into()?; + + let expression = parse_physical_expr( + f.expression.as_ref().ok_or_else(|| { + proto_error("Unexpected empty filter expression") + })?, + registry, &schema + )?; + let column_indices = f.column_indices + .iter() + .map(|i| { + let side = protobuf::JoinSide::from_i32(i.side) + .ok_or_else(|| proto_error(format!( + "Received a NestedLoopJoinExecNode message with JoinSide in Filter {}", + i.side)) + )?; + + Ok(ColumnIndex{ + index: i.index as usize, + side: side.into(), + }) + }) + .collect::>>()?; + + Ok(JoinFilter::new(expression, column_indices, schema)) + }) + .map_or(Ok(None), |v: Result| v.map(Some))?; + + Ok(Arc::new(NestedLoopJoinExec::try_new( + left, + right, + filter, + &join_type.into(), + )?)) + } } } @@ -1155,6 +1210,52 @@ impl AsExecutionPlan for PhysicalPlanNode { }), )), }) + } else if let Some(exec) = plan.downcast_ref::() { + let left = protobuf::PhysicalPlanNode::try_from_physical_plan( + exec.left().to_owned(), + extension_codec, + )?; + let right = protobuf::PhysicalPlanNode::try_from_physical_plan( + exec.right().to_owned(), + extension_codec, + )?; + + let join_type: protobuf::JoinType = exec.join_type().to_owned().into(); + let filter = exec + .filter() + .as_ref() + .map(|f| { + let expression = f.expression().to_owned().try_into()?; + let column_indices = f + .column_indices() + .iter() + .map(|i| { + let side: protobuf::JoinSide = i.side.to_owned().into(); + protobuf::ColumnIndex { + index: i.index as u32, + side: side.into(), + } + }) + .collect(); + let schema = f.schema().try_into()?; + Ok(protobuf::JoinFilter { + expression: Some(expression), + column_indices, + schema: Some(schema), + }) + }) + .map_or(Ok(None), |v: Result| v.map(Some))?; + + Ok(protobuf::PhysicalPlanNode { + physical_plan_type: Some(PhysicalPlanType::NestedLoopJoin(Box::new( + protobuf::NestedLoopJoinExecNode { + left: Some(Box::new(left)), + right: Some(Box::new(right)), + join_type: join_type.into(), + filter, + }, + ))), + }) } else { let mut buf: Vec = vec![]; match extension_codec.try_encode(plan_clone.clone(), &mut buf) { @@ -1297,7 +1398,7 @@ mod roundtrip_tests { expressions::{binary, col, lit, NotExpr}, expressions::{Avg, Column, DistinctCount, PhysicalSortExpr}, filter::FilterExec, - joins::{HashJoinExec, PartitionMode}, + joins::{HashJoinExec, NestedLoopJoinExec, PartitionMode}, limit::{GlobalLimitExec, LocalLimitExec}, sorts::sort::SortExec, AggregateExpr, ExecutionPlan, PhysicalExpr, Statistics, @@ -1433,6 +1534,34 @@ mod roundtrip_tests { Ok(()) } + #[test] + fn roundtrip_nested_loop_join() -> Result<()> { + let field_a = Field::new("col", DataType::Int64, false); + let schema_left = Schema::new(vec![field_a.clone()]); + let schema_right = Schema::new(vec![field_a]); + + let schema_left = Arc::new(schema_left); + let schema_right = Arc::new(schema_right); + for join_type in &[ + JoinType::Inner, + JoinType::Left, + JoinType::Right, + JoinType::Full, + JoinType::LeftAnti, + JoinType::RightAnti, + JoinType::LeftSemi, + JoinType::RightSemi, + ] { + roundtrip_test(Arc::new(NestedLoopJoinExec::try_new( + Arc::new(EmptyExec::new(false, schema_left.clone())), + Arc::new(EmptyExec::new(false, schema_right.clone())), + None, + join_type, + )?))?; + } + Ok(()) + } + #[test] fn rountrip_aggregate() -> Result<()> { let field_a = Field::new("a", DataType::Int64, false);