diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 416227f70de9..4868017aa7ef 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -61,6 +61,7 @@ message LogicalPlanNode { UnnestNode unnest = 30; RecursiveQueryNode recursive_query = 31; CteWorkTableScanNode cte_work_table_scan = 32; + DmlNode dml = 33; } } @@ -264,6 +265,22 @@ message CopyToNode { repeated string partition_by = 7; } +message DmlNode{ + enum Type { + UPDATE = 0; + DELETE = 1; + CTAS = 2; + INSERT_APPEND = 3; + INSERT_OVERWRITE = 4; + INSERT_REPLACE = 5; + + } + Type dml_type = 1; + LogicalPlanNode input = 2; + TableReference table_name = 3; + datafusion_common.DfSchema schema = 4; +} + message UnnestNode { LogicalPlanNode input = 1; repeated datafusion_common.Column exec_columns = 2; diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index cffb63018676..2850d350f6d8 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -4747,6 +4747,235 @@ impl<'de> serde::Deserialize<'de> for DistinctOnNode { deserializer.deserialize_struct("datafusion.DistinctOnNode", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for DmlNode { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.dml_type != 0 { + len += 1; + } + if self.input.is_some() { + len += 1; + } + if self.table_name.is_some() { + len += 1; + } + if self.schema.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.DmlNode", len)?; + if self.dml_type != 0 { + let v = dml_node::Type::try_from(self.dml_type) + .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.dml_type)))?; + struct_ser.serialize_field("dmlType", &v)?; + } + if let Some(v) = self.input.as_ref() { + struct_ser.serialize_field("input", v)?; + } + if let Some(v) = self.table_name.as_ref() { + struct_ser.serialize_field("tableName", v)?; + } + if let Some(v) = self.schema.as_ref() { + struct_ser.serialize_field("schema", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for DmlNode { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "dml_type", + "dmlType", + "input", + "table_name", + "tableName", + "schema", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + DmlType, + Input, + TableName, + Schema, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "dmlType" | "dml_type" => Ok(GeneratedField::DmlType), + "input" => Ok(GeneratedField::Input), + "tableName" | "table_name" => Ok(GeneratedField::TableName), + "schema" => Ok(GeneratedField::Schema), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = DmlNode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.DmlNode") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut dml_type__ = None; + let mut input__ = None; + let mut table_name__ = None; + let mut schema__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::DmlType => { + if dml_type__.is_some() { + return Err(serde::de::Error::duplicate_field("dmlType")); + } + dml_type__ = Some(map_.next_value::()? as i32); + } + GeneratedField::Input => { + if input__.is_some() { + return Err(serde::de::Error::duplicate_field("input")); + } + input__ = map_.next_value()?; + } + GeneratedField::TableName => { + if table_name__.is_some() { + return Err(serde::de::Error::duplicate_field("tableName")); + } + table_name__ = map_.next_value()?; + } + GeneratedField::Schema => { + if schema__.is_some() { + return Err(serde::de::Error::duplicate_field("schema")); + } + schema__ = map_.next_value()?; + } + } + } + Ok(DmlNode { + dml_type: dml_type__.unwrap_or_default(), + input: input__, + table_name: table_name__, + schema: schema__, + }) + } + } + deserializer.deserialize_struct("datafusion.DmlNode", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for dml_node::Type { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + let variant = match self { + Self::Update => "UPDATE", + Self::Delete => "DELETE", + Self::Ctas => "CTAS", + Self::InsertAppend => "INSERT_APPEND", + Self::InsertOverwrite => "INSERT_OVERWRITE", + Self::InsertReplace => "INSERT_REPLACE", + }; + serializer.serialize_str(variant) + } +} +impl<'de> serde::Deserialize<'de> for dml_node::Type { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "UPDATE", + "DELETE", + "CTAS", + "INSERT_APPEND", + "INSERT_OVERWRITE", + "INSERT_REPLACE", + ]; + + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = dml_node::Type; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + fn visit_i64(self, v: i64) -> std::result::Result + where + E: serde::de::Error, + { + i32::try_from(v) + .ok() + .and_then(|x| x.try_into().ok()) + .ok_or_else(|| { + serde::de::Error::invalid_value(serde::de::Unexpected::Signed(v), &self) + }) + } + + fn visit_u64(self, v: u64) -> std::result::Result + where + E: serde::de::Error, + { + i32::try_from(v) + .ok() + .and_then(|x| x.try_into().ok()) + .ok_or_else(|| { + serde::de::Error::invalid_value(serde::de::Unexpected::Unsigned(v), &self) + }) + } + + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "UPDATE" => Ok(dml_node::Type::Update), + "DELETE" => Ok(dml_node::Type::Delete), + "CTAS" => Ok(dml_node::Type::Ctas), + "INSERT_APPEND" => Ok(dml_node::Type::InsertAppend), + "INSERT_OVERWRITE" => Ok(dml_node::Type::InsertOverwrite), + "INSERT_REPLACE" => Ok(dml_node::Type::InsertReplace), + _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), + } + } + } + deserializer.deserialize_any(GeneratedVisitor) + } +} impl serde::Serialize for DropViewNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -10639,6 +10868,9 @@ impl serde::Serialize for LogicalPlanNode { logical_plan_node::LogicalPlanType::CteWorkTableScan(v) => { struct_ser.serialize_field("cteWorkTableScan", v)?; } + logical_plan_node::LogicalPlanType::Dml(v) => { + struct_ser.serialize_field("dml", v)?; + } } } struct_ser.end() @@ -10697,6 +10929,7 @@ impl<'de> serde::Deserialize<'de> for LogicalPlanNode { "recursiveQuery", "cte_work_table_scan", "cteWorkTableScan", + "dml", ]; #[allow(clippy::enum_variant_names)] @@ -10732,6 +10965,7 @@ impl<'de> serde::Deserialize<'de> for LogicalPlanNode { Unnest, RecursiveQuery, CteWorkTableScan, + Dml, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -10784,6 +11018,7 @@ impl<'de> serde::Deserialize<'de> for LogicalPlanNode { "unnest" => Ok(GeneratedField::Unnest), "recursiveQuery" | "recursive_query" => Ok(GeneratedField::RecursiveQuery), "cteWorkTableScan" | "cte_work_table_scan" => Ok(GeneratedField::CteWorkTableScan), + "dml" => Ok(GeneratedField::Dml), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -11021,6 +11256,13 @@ impl<'de> serde::Deserialize<'de> for LogicalPlanNode { return Err(serde::de::Error::duplicate_field("cteWorkTableScan")); } logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::CteWorkTableScan) +; + } + GeneratedField::Dml => { + if logical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("dml")); + } + logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::Dml) ; } } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index d2fda5dc8892..a25e21739f38 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -5,7 +5,7 @@ pub struct LogicalPlanNode { #[prost( oneof = "logical_plan_node::LogicalPlanType", - tags = "1, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32" + tags = "1, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33" )] pub logical_plan_type: ::core::option::Option, } @@ -75,6 +75,8 @@ pub mod logical_plan_node { RecursiveQuery(::prost::alloc::boxed::Box), #[prost(message, tag = "32")] CteWorkTableScan(super::CteWorkTableScanNode), + #[prost(message, tag = "33")] + Dml(::prost::alloc::boxed::Box), } } #[derive(Clone, PartialEq, ::prost::Message)] @@ -400,6 +402,68 @@ pub struct CopyToNode { pub partition_by: ::prost::alloc::vec::Vec<::prost::alloc::string::String>, } #[derive(Clone, PartialEq, ::prost::Message)] +pub struct DmlNode { + #[prost(enumeration = "dml_node::Type", tag = "1")] + pub dml_type: i32, + #[prost(message, optional, boxed, tag = "2")] + pub input: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(message, optional, tag = "3")] + pub table_name: ::core::option::Option, + #[prost(message, optional, tag = "4")] + pub schema: ::core::option::Option, +} +/// Nested message and enum types in `DmlNode`. +pub mod dml_node { + #[derive( + Clone, + Copy, + Debug, + PartialEq, + Eq, + Hash, + PartialOrd, + Ord, + ::prost::Enumeration + )] + #[repr(i32)] + pub enum Type { + Update = 0, + Delete = 1, + Ctas = 2, + InsertAppend = 3, + InsertOverwrite = 4, + InsertReplace = 5, + } + impl Type { + /// String value of the enum field names used in the ProtoBuf definition. + /// + /// The values are not transformed in any way and thus are considered stable + /// (if the ProtoBuf definition does not change) and safe for programmatic use. + pub fn as_str_name(&self) -> &'static str { + match self { + Self::Update => "UPDATE", + Self::Delete => "DELETE", + Self::Ctas => "CTAS", + Self::InsertAppend => "INSERT_APPEND", + Self::InsertOverwrite => "INSERT_OVERWRITE", + Self::InsertReplace => "INSERT_REPLACE", + } + } + /// Creates an enum from field names used in the ProtoBuf definition. + pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "UPDATE" => Some(Self::Update), + "DELETE" => Some(Self::Delete), + "CTAS" => Some(Self::Ctas), + "INSERT_APPEND" => Some(Self::InsertAppend), + "INSERT_OVERWRITE" => Some(Self::InsertOverwrite), + "INSERT_REPLACE" => Some(Self::InsertReplace), + _ => None, + } + } + } +} +#[derive(Clone, PartialEq, ::prost::Message)] pub struct UnnestNode { #[prost(message, optional, boxed, tag = "1")] pub input: ::core::option::Option<::prost::alloc::boxed::Box>, diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 6ab3e0c9096c..e04a89a03dae 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -22,9 +22,9 @@ use datafusion_common::{ exec_datafusion_err, internal_err, plan_datafusion_err, RecursionUnnestOption, Result, ScalarValue, TableReference, UnnestOptions, }; +use datafusion_expr::dml::InsertOp; use datafusion_expr::expr::{Alias, Placeholder, Sort}; use datafusion_expr::expr::{Unnest, WildcardOptions}; -use datafusion_expr::ExprFunctionExt; use datafusion_expr::{ expr::{self, InList, WindowFunction}, logical_plan::{PlanType, StringifiedPlan}, @@ -33,6 +33,7 @@ use datafusion_expr::{ JoinConstraint, JoinType, Like, Operator, TryCast, WindowFrame, WindowFrameBound, WindowFrameUnits, }; +use datafusion_expr::{ExprFunctionExt, WriteOp}; use datafusion_proto_common::{from_proto::FromOptionalField, FromProtoError as Error}; use crate::protobuf::plan_type::PlanTypeEnum::{ @@ -217,6 +218,21 @@ impl From for JoinConstraint { } } +impl From for WriteOp { + fn from(t: protobuf::dml_node::Type) -> Self { + match t { + protobuf::dml_node::Type::Update => WriteOp::Update, + protobuf::dml_node::Type::Delete => WriteOp::Delete, + protobuf::dml_node::Type::InsertAppend => WriteOp::Insert(InsertOp::Append), + protobuf::dml_node::Type::InsertOverwrite => { + WriteOp::Insert(InsertOp::Overwrite) + } + protobuf::dml_node::Type::InsertReplace => WriteOp::Insert(InsertOp::Replace), + protobuf::dml_node::Type::Ctas => WriteOp::Ctas, + } + } +} + pub fn parse_expr( proto: &protobuf::LogicalExprNode, registry: &dyn FunctionRegistry, diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index addafeb7629d..53b683bac66a 100644 --- a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -21,8 +21,8 @@ use std::sync::Arc; use crate::protobuf::logical_plan_node::LogicalPlanType::CustomScan; use crate::protobuf::{ - ColumnUnnestListItem, ColumnUnnestListRecursion, CteWorkTableScanNode, - CustomTableScanNode, SortExprNodeCollection, + dml_node, ColumnUnnestListItem, ColumnUnnestListRecursion, CteWorkTableScanNode, + CustomTableScanNode, DmlNode, SortExprNodeCollection, }; use crate::{ convert_required, into_required, @@ -70,7 +70,8 @@ use datafusion_expr::{ Statement, WindowUDF, }; use datafusion_expr::{ - AggregateUDF, ColumnUnnestList, FetchType, RecursiveQuery, SkipType, Unnest, + AggregateUDF, ColumnUnnestList, DmlStatement, FetchType, RecursiveQuery, SkipType, + Unnest, }; use self::to_proto::{serialize_expr, serialize_exprs}; @@ -938,6 +939,14 @@ impl AsLogicalPlan for LogicalPlanNode { )? .build() } + LogicalPlanType::Dml(dml_node) => Ok(LogicalPlan::Dml( + datafusion::logical_expr::DmlStatement::new( + from_table_reference(dml_node.table_name.as_ref(), "DML ")?, + Arc::new(convert_required!(dml_node.schema)?), + dml_node.dml_type().into(), + Arc::new(into_logical_plan!(dml_node.input, ctx, extension_codec)?), + ), + )), } } @@ -1647,9 +1656,25 @@ impl AsLogicalPlan for LogicalPlanNode { LogicalPlan::Statement(_) => Err(proto_error( "LogicalPlan serde is not yet implemented for Statement", )), - LogicalPlan::Dml(_) => Err(proto_error( - "LogicalPlan serde is not yet implemented for Dml", - )), + LogicalPlan::Dml(DmlStatement { + table_name, + table_schema, + op, + input, + .. + }) => { + let input = + LogicalPlanNode::try_from_logical_plan(input, extension_codec)?; + let dml_type: dml_node::Type = op.into(); + Ok(LogicalPlanNode { + logical_plan_type: Some(LogicalPlanType::Dml(Box::new(DmlNode { + input: Some(Box::new(input)), + schema: Some(table_schema.try_into()?), + table_name: Some(table_name.clone().into()), + dml_type: dml_type.into(), + }))), + }) + } LogicalPlan::Copy(dml::CopyTo { input, output_url, diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index caceb3db164c..6d1d4f30610c 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -20,10 +20,12 @@ //! processes. use datafusion_common::{TableReference, UnnestOptions}; +use datafusion_expr::dml::InsertOp; use datafusion_expr::expr::{ self, Alias, Between, BinaryExpr, Cast, GroupingSet, InList, Like, Placeholder, ScalarFunction, Unnest, }; +use datafusion_expr::WriteOp; use datafusion_expr::{ logical_plan::PlanType, logical_plan::StringifiedPlan, Expr, JoinConstraint, JoinType, SortExpr, TryCast, WindowFrame, WindowFrameBound, WindowFrameUnits, @@ -686,3 +688,18 @@ impl From for protobuf::JoinConstraint { } } } + +impl From<&WriteOp> for protobuf::dml_node::Type { + fn from(t: &WriteOp) -> Self { + match t { + WriteOp::Insert(InsertOp::Append) => protobuf::dml_node::Type::InsertAppend, + WriteOp::Insert(InsertOp::Overwrite) => { + protobuf::dml_node::Type::InsertOverwrite + } + WriteOp::Insert(InsertOp::Replace) => protobuf::dml_node::Type::InsertReplace, + WriteOp::Delete => protobuf::dml_node::Type::Delete, + WriteOp::Update => protobuf::dml_node::Type::Update, + WriteOp::Ctas => protobuf::dml_node::Type::Ctas, + } + } +} diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 687406c7db41..d7620e65c41e 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -375,6 +375,44 @@ async fn roundtrip_logical_plan_sort() -> Result<()> { Ok(()) } +#[tokio::test] +async fn roundtrip_logical_plan_dml() -> Result<()> { + let ctx = SessionContext::new(); + let schema = Schema::new(vec![ + Field::new("a", DataType::Int64, true), + Field::new("b", DataType::Decimal128(15, 2), true), + ]); + + ctx.register_csv( + "t1", + "tests/testdata/test.csv", + CsvReadOptions::default().schema(&schema), + ) + .await?; + let queries = [ + "INSERT INTO T1 VALUES (1, null)", + "INSERT OVERWRITE T1 VALUES (1, null)", + "REPLACE INTO T1 VALUES (1, null)", + "INSERT OR REPLACE INTO T1 VALUES (1, null)", + "DELETE FROM T1", + "UPDATE T1 SET a = 1", + "CREATE TABLE T2 AS SELECT * FROM T1", + ]; + for query in queries { + let plan = ctx.sql(query).await?.into_optimized_plan()?; + let bytes = logical_plan_to_bytes(&plan)?; + let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx)?; + assert_eq!( + format!("{plan}"), + format!("{logical_round_trip}"), + "failed query roundtrip: {}", + query + ); + } + + Ok(()) +} + #[tokio::test] async fn roundtrip_logical_plan_copy_to_sql_options() -> Result<()> { let ctx = SessionContext::new();