From 92470226fd30fc9ac410513addc18c1073c26cca Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Fri, 13 Jan 2023 02:09:50 -0700 Subject: [PATCH] Update substrait create to depend on version of DataFusion in the repo (#4879) --- Cargo.toml | 2 +- datafusion/substrait/Cargo.toml | 5 +- datafusion/substrait/src/consumer.rs | 374 +++++++++++++++--------- datafusion/substrait/src/producer.rs | 237 ++++++++++----- datafusion/substrait/src/serializer.rs | 15 +- datafusion/substrait/tests/roundtrip.rs | 74 +++-- datafusion/substrait/tests/serialize.rs | 6 +- 7 files changed, 448 insertions(+), 265 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 1ab431ea316a..e67ec42a9fe1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,7 +16,7 @@ # under the License. [workspace] -exclude = ["datafusion-cli"] +exclude = ["datafusion-cli", "datafusion/substrait"] members = [ "datafusion/common", "datafusion/core", diff --git a/datafusion/substrait/Cargo.toml b/datafusion/substrait/Cargo.toml index b8c0e56d2566..091b3a565b44 100644 --- a/datafusion/substrait/Cargo.toml +++ b/datafusion/substrait/Cargo.toml @@ -17,12 +17,13 @@ [package] name = "datafusion-substrait" -version = "0.1.0" +version = "16.0.0" edition = "2021" +rust-version = "1.62" [dependencies] async-recursion = "1.0" -datafusion = "13.0" +datafusion = { version = "16.0.0", path = "../core" } prost = "0.9" prost-types = "0.9" substrait = "0.2" diff --git a/datafusion/substrait/src/consumer.rs b/datafusion/substrait/src/consumer.rs index c747a30a6bec..2f0e88969656 100644 --- a/datafusion/substrait/src/consumer.rs +++ b/datafusion/substrait/src/consumer.rs @@ -17,35 +17,34 @@ use async_recursion::async_recursion; use datafusion::common::{DFField, DFSchema, DFSchemaRef}; -use datafusion::logical_expr::{LogicalPlan, aggregate_function}; -use datafusion::logical_plan::build_join_schema; +use datafusion::logical_expr::build_join_schema; +use datafusion::logical_expr::expr; +use datafusion::logical_expr::{ + aggregate_function, BinaryExpr, Case, Expr, LogicalPlan, Operator, +}; use datafusion::prelude::JoinType; +use datafusion::sql::TableReference; use datafusion::{ error::{DataFusionError, Result}, - logical_plan::{Expr, Operator}, optimizer::utils::split_conjunction, prelude::{Column, DataFrame, SessionContext}, scalar::ScalarValue, }; - -use datafusion::sql::TableReference; use substrait::protobuf::{ aggregate_function::AggregationInvocation, expression::{ - field_reference::ReferenceType::DirectReference, - literal::LiteralType, - MaskExpression, - reference_segment::ReferenceType::StructField, - RexType, + field_reference::ReferenceType::DirectReference, literal::LiteralType, + reference_segment::ReferenceType::StructField, MaskExpression, RexType, }, extensions::simple_extension_declaration::MappingType, function_argument::ArgType, read_rel::ReadType, rel::RelType, - sort_field::{SortKind::*, SortDirection}, + sort_field::{SortDirection, SortKind::*}, AggregateFunction, Expression, Plan, Rel, }; +use datafusion::logical_expr::expr::Sort; use std::collections::HashMap; use std::str::FromStr; use std::sync::Arc; @@ -65,8 +64,6 @@ pub fn name_to_op(name: &str) -> Result { "mod" => Ok(Operator::Modulo), "and" => Ok(Operator::And), "or" => Ok(Operator::Or), - "like" => Ok(Operator::Like), - "not_like" => Ok(Operator::NotLike), "is_distinct_from" => Ok(Operator::IsDistinctFrom), "is_not_distinct_from" => Ok(Operator::IsNotDistinctFrom), "regex_match" => Ok(Operator::RegexMatch), @@ -87,18 +84,27 @@ pub fn name_to_op(name: &str) -> Result { } /// Convert Substrait Plan to DataFusion DataFrame -pub async fn from_substrait_plan(ctx: &mut SessionContext, plan: &Plan) -> Result> { +pub async fn from_substrait_plan( + ctx: &mut SessionContext, + plan: &Plan, +) -> Result { // Register function extension - let function_extension = plan.extensions + let function_extension = plan + .extensions .iter() .map(|e| match &e.mapping_type { - Some(ext) => { - match ext { - MappingType::ExtensionFunction(ext_f) => Ok((ext_f.function_anchor, &ext_f.name)), - _ => Err(DataFusionError::NotImplemented(format!("Extension type not supported: {:?}", ext))) + Some(ext) => match ext { + MappingType::ExtensionFunction(ext_f) => { + Ok((ext_f.function_anchor, &ext_f.name)) } - } - None => Err(DataFusionError::NotImplemented("Cannot parse empty extension".to_string())) + _ => Err(DataFusionError::NotImplemented(format!( + "Extension type not supported: {:?}", + ext + ))), + }, + None => Err(DataFusionError::NotImplemented( + "Cannot parse empty extension".to_string(), + )), }) .collect::>>()?; // Parse relations @@ -107,15 +113,14 @@ pub async fn from_substrait_plan(ctx: &mut SessionContext, plan: &Plan) -> Resul match plan.relations[0].rel_type.as_ref() { Some(rt) => match rt { substrait::protobuf::plan_rel::RelType::Rel(rel) => { - Ok(from_substrait_rel(ctx, &rel, &function_extension).await?) + Ok(from_substrait_rel(ctx, rel, &function_extension).await?) }, substrait::protobuf::plan_rel::RelType::Root(root) => { - Ok(from_substrait_rel(ctx, &root.input.as_ref().unwrap(), &function_extension).await?) + Ok(from_substrait_rel(ctx, root.input.as_ref().unwrap(), &function_extension).await?) } }, None => Err(DataFusionError::Internal("Cannot parse plan relation: None".to_string())) } - }, _ => Err(DataFusionError::NotImplemented(format!( "Substrait plan with more than 1 relation trees not supported. Number of relation trees: {:?}", @@ -126,14 +131,18 @@ pub async fn from_substrait_plan(ctx: &mut SessionContext, plan: &Plan) -> Resul /// Convert Substrait Rel to DataFusion DataFrame #[async_recursion] -pub async fn from_substrait_rel(ctx: &mut SessionContext, rel: &Rel, extensions: &HashMap) -> Result> { +pub async fn from_substrait_rel( + ctx: &mut SessionContext, + rel: &Rel, + extensions: &HashMap, +) -> Result { match &rel.rel_type { Some(RelType::Project(p)) => { if let Some(input) = p.input.as_ref() { let input = from_substrait_rel(ctx, input, extensions).await?; let mut exprs: Vec = vec![]; for e in &p.expressions { - let x = from_substrait_rex(e, &input.schema(), extensions).await?; + let x = from_substrait_rex(e, input.schema(), extensions).await?; exprs.push(x.as_ref().clone()); } input.select(exprs) @@ -147,7 +156,8 @@ pub async fn from_substrait_rel(ctx: &mut SessionContext, rel: &Rel, extensions: if let Some(input) = filter.input.as_ref() { let input = from_substrait_rel(ctx, input, extensions).await?; if let Some(condition) = filter.condition.as_ref() { - let expr = from_substrait_rex(condition, &input.schema(), extensions).await?; + let expr = + from_substrait_rex(condition, input.schema(), extensions).await?; input.filter(expr.as_ref().clone()) } else { Err(DataFusionError::NotImplemented( @@ -177,7 +187,12 @@ pub async fn from_substrait_rel(ctx: &mut SessionContext, rel: &Rel, extensions: let input = from_substrait_rel(ctx, input, extensions).await?; let mut sorts: Vec = vec![]; for s in &sort.sorts { - let expr = from_substrait_rex(&s.expr.as_ref().unwrap(), &input.schema(), extensions).await?; + let expr = from_substrait_rex( + s.expr.as_ref().unwrap(), + input.schema(), + extensions, + ) + .await?; let asc_nullfirst = match &s.sort_kind { Some(k) => match k { Direction(d) => { @@ -189,32 +204,25 @@ pub async fn from_substrait_rel(ctx: &mut SessionContext, rel: &Rel, extensions: SortDirection::AscNullsLast => Ok((true, false)), SortDirection::DescNullsFirst => Ok((false, true)), SortDirection::DescNullsLast => Ok((false, false)), - SortDirection::Clustered => { - Err(DataFusionError::NotImplemented( - "Sort with direction clustered is not yet supported".to_string(), - )) - }, - SortDirection::Unspecified => { - Err(DataFusionError::NotImplemented( - "Unspecified sort direction is invalid".to_string(), - )) - } + SortDirection::Clustered => + Err(DataFusionError::NotImplemented("Sort with direction clustered is not yet supported".to_string())) + , + SortDirection::Unspecified => + Err(DataFusionError::NotImplemented("Unspecified sort direction is invalid".to_string())) } } ComparisonFunctionReference(_) => { - Err(DataFusionError::NotImplemented( - "Sort using comparison function reference is not supported".to_string(), - )) + Err(DataFusionError::NotImplemented("Sort using comparison function reference is not supported".to_string())) }, }, - None => { - Err(DataFusionError::NotImplemented( - "Sort without sort kind is invalid".to_string(), - )) - }, + None => Err(DataFusionError::NotImplemented("Sort without sort kind is invalid".to_string())) }; let (asc, nulls_first) = asc_nullfirst.unwrap(); - sorts.push(Expr::Sort { expr: Box::new(expr.as_ref().clone()), asc: asc, nulls_first: nulls_first }); + sorts.push(Expr::Sort(Sort { + expr: Box::new(expr.as_ref().clone()), + asc, + nulls_first, + })); } input.sort(sorts) } else { @@ -230,35 +238,55 @@ pub async fn from_substrait_rel(ctx: &mut SessionContext, rel: &Rel, extensions: let mut aggr_expr = vec![]; let groupings = match agg.groupings.len() { - 1 => { Ok(&agg.groupings[0]) }, - _ => { - Err(DataFusionError::NotImplemented( - "Aggregate with multiple grouping sets is not supported".to_string(), - )) - } + 1 => Ok(&agg.groupings[0]), + _ => Err(DataFusionError::NotImplemented( + "Aggregate with multiple grouping sets is not supported" + .to_string(), + )), }; for e in &groupings?.grouping_expressions { - let x = from_substrait_rex(&e, &input.schema(), extensions).await?; + let x = from_substrait_rex(e, input.schema(), extensions).await?; group_expr.push(x.as_ref().clone()); } for m in &agg.measures { let filter = match &m.filter { - Some(fil) => Some(Box::new(from_substrait_rex(fil, &input.schema(), extensions).await?.as_ref().clone())), - None => None + Some(fil) => Some(Box::new( + from_substrait_rex(fil, input.schema(), extensions) + .await? + .as_ref() + .clone(), + )), + None => None, }; let agg_func = match &m.measure { Some(f) => { - let distinct = match f.invocation { - _ if f.invocation == AggregationInvocation::Distinct as i32 => true, - _ if f.invocation == AggregationInvocation::All as i32 => false, - _ => false + let distinct = match f.invocation { + _ if f.invocation + == AggregationInvocation::Distinct as i32 => + { + true + } + _ if f.invocation + == AggregationInvocation::All as i32 => + { + false + } + _ => false, }; - from_substrait_agg_func(&f, &input.schema(), extensions, filter, distinct).await - }, + from_substrait_agg_func( + f, + input.schema(), + extensions, + filter, + distinct, + ) + .await + } None => Err(DataFusionError::NotImplemented( - "Aggregate without aggregate function is not supported".to_string(), + "Aggregate without aggregate function is not supported" + .to_string(), )), }; aggr_expr.push(agg_func?.as_ref().clone()); @@ -272,41 +300,50 @@ pub async fn from_substrait_rel(ctx: &mut SessionContext, rel: &Rel, extensions: } } Some(RelType::Join(join)) => { - let left = from_substrait_rel(ctx, &join.left.as_ref().unwrap(), extensions).await?; - let right = from_substrait_rel(ctx, &join.right.as_ref().unwrap(), extensions).await?; + let left = + from_substrait_rel(ctx, join.left.as_ref().unwrap(), extensions).await?; + let right = + from_substrait_rel(ctx, join.right.as_ref().unwrap(), extensions).await?; let join_type = match join.r#type { 1 => JoinType::Inner, 2 => JoinType::Left, 3 => JoinType::Right, 4 => JoinType::Full, - 5 => JoinType::Anti, - 6 => JoinType::Semi, - _ => return Err(DataFusionError::Internal("invalid join type".to_string())), + 5 => JoinType::LeftAnti, + 6 => JoinType::LeftSemi, + _ => { + return Err(DataFusionError::Internal( + "invalid join type".to_string(), + )) + } }; - let mut predicates = vec![]; - let schema = build_join_schema(&left.schema(), &right.schema(), &JoinType::Inner)?; - let on = from_substrait_rex(&join.expression.as_ref().unwrap(), &schema, extensions).await?; - split_conjunction(&on, &mut predicates); + let schema = + build_join_schema(left.schema(), right.schema(), &JoinType::Inner)?; + let on = from_substrait_rex( + join.expression.as_ref().unwrap(), + &schema, + extensions, + ) + .await?; + let predicates = split_conjunction(&on); let pairs = predicates .iter() .map(|p| match p { - Expr::BinaryExpr { + Expr::BinaryExpr(BinaryExpr { left, op: Operator::Eq, right, - } => match (left.as_ref(), right.as_ref()) { - (Expr::Column(l), Expr::Column(r)) => Ok((l.flat_name(), r.flat_name())), - _ => { - return Err(DataFusionError::Internal( - "invalid join condition".to_string(), - )) + }) => match (left.as_ref(), right.as_ref()) { + (Expr::Column(l), Expr::Column(r)) => { + Ok((l.flat_name(), r.flat_name())) } - }, - _ => { - return Err(DataFusionError::Internal( + _ => Err(DataFusionError::Internal( "invalid join condition".to_string(), - )) - } + )), + }, + _ => Err(DataFusionError::Internal( + "invalid join condition".to_string(), + )), }) .collect::>>()?; let left_cols: Vec<&str> = pairs.iter().map(|(l, _)| l.as_str()).collect(); @@ -334,7 +371,7 @@ pub async fn from_substrait_rel(ctx: &mut SessionContext, rel: &Rel, extensions: table: &nt.names[2], }, }; - let t = ctx.table(table_reference)?; + let t = ctx.table(table_reference).await?; match &read.projection { Some(MaskExpression { select, .. }) => match &select.as_ref() { Some(projection) => { @@ -343,19 +380,23 @@ pub async fn from_substrait_rel(ctx: &mut SessionContext, rel: &Rel, extensions: .iter() .map(|item| item.field as usize) .collect(); - match t.to_logical_plan()? { + match t.into_optimized_plan()? { LogicalPlan::TableScan(scan) => { - let mut scan = scan.clone(); let fields: Vec = column_indices .iter() .map(|i| scan.projected_schema.field(*i).clone()) .collect(); + // clippy thinks this clone is redundant but it is not + #[allow(clippy::redundant_clone)] + let mut scan = scan.clone(); scan.projection = Some(column_indices); - scan.projected_schema = DFSchemaRef::new( - DFSchema::new_with_metadata(fields, HashMap::new())?, - ); + scan.projected_schema = + DFSchemaRef::new(DFSchema::new_with_metadata( + fields, + HashMap::new(), + )?); let plan = LogicalPlan::TableScan(scan); - Ok(Arc::new(DataFrame::new(ctx.state.clone(), &plan))) + Ok(DataFrame::new(ctx.state(), plan)) } _ => Err(DataFusionError::Internal( "unexpected plan for table".to_string(), @@ -384,60 +425,62 @@ pub async fn from_substrait_agg_func( input_schema: &DFSchema, extensions: &HashMap, filter: Option>, - distinct: bool + distinct: bool, ) -> Result> { let mut args: Vec = vec![]; for arg in &f.arguments { let arg_expr = match &arg.arg_type { - Some(ArgType::Value(e)) => from_substrait_rex(e, input_schema, extensions).await, + Some(ArgType::Value(e)) => { + from_substrait_rex(e, input_schema, extensions).await + } _ => Err(DataFusionError::NotImplemented( - "Aggregated function argument non-Value type not supported".to_string(), - )) + "Aggregated function argument non-Value type not supported".to_string(), + )), }; args.push(arg_expr?.as_ref().clone()); } let fun = match extensions.get(&f.function_reference) { - Some(function_name) => aggregate_function::AggregateFunction::from_str(function_name), + Some(function_name) => { + aggregate_function::AggregateFunction::from_str(function_name) + } None => Err(DataFusionError::NotImplemented(format!( - "Aggregated function not found: function anchor = {:?}", - f.function_reference - ) - )) + "Aggregated function not found: function anchor = {:?}", + f.function_reference + ))), }; - Ok( - Arc::new( - Expr::AggregateFunction { - fun: fun.unwrap(), - args: args, - distinct: distinct, - filter: filter - } - ) - ) + Ok(Arc::new(Expr::AggregateFunction(expr::AggregateFunction { + fun: fun.unwrap(), + args, + distinct, + filter, + }))) } /// Convert Substrait Rex to DataFusion Expr #[async_recursion] -pub async fn from_substrait_rex(e: &Expression, input_schema: &DFSchema, extensions: &HashMap) -> Result> { +pub async fn from_substrait_rex( + e: &Expression, + input_schema: &DFSchema, + extensions: &HashMap, +) -> Result> { match &e.rex_type { Some(RexType::Selection(field_ref)) => match &field_ref.reference_type { Some(DirectReference(direct)) => match &direct.reference_type.as_ref() { Some(StructField(x)) => match &x.child.as_ref() { Some(_) => Err(DataFusionError::NotImplemented( - "Direct reference StructField with child is not supported".to_string(), + "Direct reference StructField with child is not supported" + .to_string(), )), None => Ok(Arc::new(Expr::Column(Column { relation: None, - name: input_schema - .field(x.field as usize) - .name() - .to_string(), + name: input_schema.field(x.field as usize).name().to_string(), }))), }, _ => Err(DataFusionError::NotImplemented( - "Direct reference with types other than StructField is not supported".to_string(), + "Direct reference with types other than StructField is not supported" + .to_string(), )), }, _ => Err(DataFusionError::NotImplemented( @@ -453,45 +496,84 @@ pub async fn from_substrait_rex(e: &Expression, input_schema: &DFSchema, extensi if i == 0 { // Check if the first element is type base expression if if_expr.then.is_none() { - expr = Some(Box::new(from_substrait_rex(&if_expr.r#if.as_ref().unwrap(), input_schema, extensions).await?.as_ref().clone())); + expr = Some(Box::new( + from_substrait_rex( + if_expr.r#if.as_ref().unwrap(), + input_schema, + extensions, + ) + .await? + .as_ref() + .clone(), + )); continue; } } - when_then_expr.push( - ( - Box::new(from_substrait_rex(&if_expr.r#if.as_ref().unwrap(), input_schema, extensions).await?.as_ref().clone()), - Box::new(from_substrait_rex(&if_expr.then.as_ref().unwrap(), input_schema, extensions).await?.as_ref().clone()) + when_then_expr.push(( + Box::new( + from_substrait_rex( + if_expr.r#if.as_ref().unwrap(), + input_schema, + extensions, + ) + .await? + .as_ref() + .clone(), ), - ); + Box::new( + from_substrait_rex( + if_expr.then.as_ref().unwrap(), + input_schema, + extensions, + ) + .await? + .as_ref() + .clone(), + ), + )); } // Parse `else` let else_expr = match &if_then.r#else { Some(e) => Some(Box::new( - from_substrait_rex(&e, input_schema, extensions).await?.as_ref().clone(), - )), - None => None + from_substrait_rex(e, input_schema, extensions) + .await? + .as_ref() + .clone(), + )), + None => None, }; - Ok(Arc::new(Expr::Case { expr: expr, when_then_expr: when_then_expr, else_expr: else_expr })) - }, + Ok(Arc::new(Expr::Case(Case { + expr, + when_then_expr, + else_expr, + }))) + } Some(RexType::ScalarFunction(f)) => { assert!(f.arguments.len() == 2); let op = match extensions.get(&f.function_reference) { - Some(fname) => name_to_op(fname), - None => Err(DataFusionError::NotImplemented(format!( - "Aggregated function not found: function reference = {:?}", - f.function_reference - ) - )) + Some(fname) => name_to_op(fname), + None => Err(DataFusionError::NotImplemented(format!( + "Aggregated function not found: function reference = {:?}", + f.function_reference + ))), }; match (&f.arguments[0].arg_type, &f.arguments[1].arg_type) { (Some(ArgType::Value(l)), Some(ArgType::Value(r))) => { - Ok(Arc::new(Expr::BinaryExpr { - left: Box::new(from_substrait_rex(l, input_schema, extensions).await?.as_ref().clone()), + Ok(Arc::new(Expr::BinaryExpr(BinaryExpr { + left: Box::new( + from_substrait_rex(l, input_schema, extensions) + .await? + .as_ref() + .clone(), + ), op: op?, right: Box::new( - from_substrait_rex(r, input_schema, extensions).await?.as_ref().clone(), + from_substrait_rex(r, input_schema, extensions) + .await? + .as_ref() + .clone(), ), - })) + }))) } (l, r) => Err(DataFusionError::NotImplemented(format!( "Invalid arguments for binary expression: {:?} and {:?}", @@ -507,10 +589,10 @@ pub async fn from_substrait_rex(e: &Expression, input_schema: &DFSchema, extensi Ok(Arc::new(Expr::Literal(ScalarValue::Int16(Some(*n as i16))))) } Some(LiteralType::I32(n)) => { - Ok(Arc::new(Expr::Literal(ScalarValue::Int32(Some(*n as i32))))) + Ok(Arc::new(Expr::Literal(ScalarValue::Int32(Some(*n))))) } Some(LiteralType::I64(n)) => { - Ok(Arc::new(Expr::Literal(ScalarValue::Int64(Some(*n as i64))))) + Ok(Arc::new(Expr::Literal(ScalarValue::Int64(Some(*n))))) } Some(LiteralType::Boolean(b)) => { Ok(Arc::new(Expr::Literal(ScalarValue::Boolean(Some(*b))))) @@ -524,12 +606,12 @@ pub async fn from_substrait_rex(e: &Expression, input_schema: &DFSchema, extensi Some(LiteralType::Fp64(f)) => { Ok(Arc::new(Expr::Literal(ScalarValue::Float64(Some(*f))))) } - Some(LiteralType::String(s)) => Ok(Arc::new(Expr::Literal(ScalarValue::Utf8( - Some(s.clone()), - )))), - Some(LiteralType::Binary(b)) => Ok(Arc::new(Expr::Literal(ScalarValue::Binary(Some( - b.clone(), - ))))), + Some(LiteralType::String(s)) => { + Ok(Arc::new(Expr::Literal(ScalarValue::Utf8(Some(s.clone()))))) + } + Some(LiteralType::Binary(b)) => Ok(Arc::new(Expr::Literal( + ScalarValue::Binary(Some(b.clone())), + ))), _ => { return Err(DataFusionError::NotImplemented(format!( "Unsupported literal_type: {:?}", diff --git a/datafusion/substrait/src/producer.rs b/datafusion/substrait/src/producer.rs index 78532046bced..ab32983b9b99 100644 --- a/datafusion/substrait/src/producer.rs +++ b/datafusion/substrait/src/producer.rs @@ -19,11 +19,16 @@ use std::collections::HashMap; use datafusion::{ error::{DataFusionError, Result}, - logical_plan::{DFSchemaRef, Expr, JoinConstraint, LogicalPlan, Operator}, prelude::JoinType, scalar::ScalarValue, }; +use datafusion::common::DFSchemaRef; +#[allow(unused_imports)] +use datafusion::logical_expr::aggregate_function; +use datafusion::logical_expr::expr::{BinaryExpr, Case, Sort}; +use datafusion::logical_expr::{expr, Between, JoinConstraint, LogicalPlan, Operator}; +use datafusion::prelude::{binary_expr, Expr}; use substrait::protobuf::{ aggregate_function::AggregationInvocation, aggregate_rel::{Grouping, Measure}, @@ -32,36 +37,37 @@ use substrait::protobuf::{ if_then::IfClause, literal::LiteralType, mask_expression::{StructItem, StructSelect}, - reference_segment, - FieldReference, IfThen, Literal, MaskExpression, ReferenceSegment, RexType, ScalarFunction, + reference_segment, FieldReference, IfThen, Literal, MaskExpression, + ReferenceSegment, RexType, ScalarFunction, + }, + extensions::{ + self, + simple_extension_declaration::{ExtensionFunction, MappingType}, }, - extensions::{self, simple_extension_declaration::{MappingType, ExtensionFunction}}, function_argument::ArgType, plan_rel, read_rel::{NamedTable, ReadType}, rel::RelType, - sort_field::{ - SortDirection, - SortKind, - }, - AggregateRel, Expression, FetchRel, FilterRel, FunctionArgument, JoinRel, NamedStruct, ProjectRel, ReadRel, SortField, SortRel, - PlanRel, - Plan, Rel, RelRoot, AggregateFunction, + sort_field::{SortDirection, SortKind}, + AggregateFunction, AggregateRel, Expression, FetchRel, FilterRel, FunctionArgument, + JoinRel, NamedStruct, Plan, PlanRel, ProjectRel, ReadRel, Rel, RelRoot, SortField, + SortRel, }; /// Convert DataFusion LogicalPlan to Substrait Plan pub fn to_substrait_plan(plan: &LogicalPlan) -> Result> { // Parse relation nodes - let mut extension_info: (Vec, HashMap) = (vec![], HashMap::new()); + let mut extension_info: ( + Vec, + HashMap, + ) = (vec![], HashMap::new()); // Generate PlanRel(s) // Note: Only 1 relation tree is currently supported let plan_rels = vec![PlanRel { - rel_type: Some(plan_rel::RelType::Root( - RelRoot { - input: Some(*to_substrait_rel(plan, &mut extension_info)?), - names: plan.schema().field_names(), - } - )) + rel_type: Some(plan_rel::RelType::Root(RelRoot { + input: Some(*to_substrait_rel(plan, &mut extension_info)?), + names: plan.schema().field_names(), + })), }]; let (function_extensions, _) = extension_info; @@ -74,11 +80,16 @@ pub fn to_substrait_plan(plan: &LogicalPlan) -> Result> { advanced_extensions: None, expected_type_urls: vec![], })) - } /// Convert DataFusion LogicalPlan to Substrait Rel -pub fn to_substrait_rel(plan: &LogicalPlan, extension_info: &mut (Vec, HashMap)) -> Result> { +pub fn to_substrait_rel( + plan: &LogicalPlan, + extension_info: &mut ( + Vec, + HashMap, + ), +) -> Result> { match plan { LogicalPlan::TableScan(scan) => { let projection = scan.projection.as_ref().map(|p| { @@ -138,7 +149,11 @@ pub fn to_substrait_rel(plan: &LogicalPlan, extension_info: &mut (Vec { let input = to_substrait_rel(filter.input.as_ref(), extension_info)?; - let filter_expr = to_substrait_rex(&filter.predicate, filter.input.schema(), extension_info)?; + let filter_expr = to_substrait_rex( + &filter.predicate, + filter.input.schema(), + extension_info, + )?; Ok(Box::new(Rel { rel_type: Some(RelType::Filter(Box::new(FilterRel { common: None, @@ -150,10 +165,7 @@ pub fn to_substrait_rel(plan: &LogicalPlan, extension_info: &mut (Vec { let input = to_substrait_rel(limit.input.as_ref(), extension_info)?; - let limit_fetch = match limit.fetch { - Some(count) => count, - None => 0, - }; + let limit_fetch = limit.fetch.unwrap_or(0); Ok(Box::new(Rel { rel_type: Some(RelType::Fetch(Box::new(FetchRel { common: None, @@ -193,13 +205,15 @@ pub fn to_substrait_rel(plan: &LogicalPlan, extension_info: &mut (Vec>>()?; - + Ok(Box::new(Rel { rel_type: Some(RelType::Aggregate(Box::new(AggregateRel { common: None, input: Some(input), - groupings: vec![Grouping { grouping_expressions: grouping }], //groupings, - measures: measures, + groupings: vec![Grouping { + grouping_expressions: grouping, + }], //groupings, + measures, advanced_extension: None, }))), })) @@ -209,14 +223,16 @@ pub fn to_substrait_rel(plan: &LogicalPlan, extension_info: &mut (Vec>>()?; Ok(Box::new(Rel { rel_type: Some(RelType::Aggregate(Box::new(AggregateRel { common: None, input: Some(input), - groupings: vec![Grouping { grouping_expressions: grouping }], + groupings: vec![Grouping { + grouping_expressions: grouping, + }], measures: vec![], advanced_extension: None, }))), @@ -230,8 +246,9 @@ pub fn to_substrait_rel(plan: &LogicalPlan, extension_info: &mut (Vec 2, JoinType::Right => 3, JoinType::Full => 4, - JoinType::Anti => 5, - JoinType::Semi => 6, + JoinType::LeftAnti => 5, + JoinType::LeftSemi => 6, + _ => panic!(), // TODO }; // we only support basic joins so return an error for anything not yet supported if join.null_equals_null { @@ -251,14 +268,11 @@ pub fn to_substrait_rel(plan: &LogicalPlan, extension_info: &mut (Vec = join + // build a single expression for the ON condition, such as `l.a = r.a AND l.b = r.b` + let join_expression = join .on .iter() - .map(|(l, r)| Expr::Column(l.clone()).eq(Expr::Column(r.clone()))) - .collect(); - // build a single expression for the ON condition, such as `l.a = r.a AND l.b = r.b` - let join_expression = join_expression - .into_iter() + .map(|(l, r)| binary_expr(l.clone(), Operator::Eq, r.clone())) .reduce(|acc: Expr, expr: Expr| acc.and(expr)); if let Some(e) = join_expression { Ok(Box::new(Rel { @@ -267,7 +281,11 @@ pub fn to_substrait_rel(plan: &LogicalPlan, extension_info: &mut (Vec &'static str { Operator::Modulo => "mod", Operator::And => "and", Operator::Or => "or", - Operator::Like => "like", - Operator::NotLike => "not_like", Operator::IsDistinctFrom => "is_distinct_from", Operator::IsNotDistinctFrom => "is_not_distinct_from", Operator::RegexMatch => "regex_match", @@ -322,9 +338,17 @@ pub fn operator_to_name(op: Operator) -> &'static str { } } -pub fn to_substrait_agg_measure(expr: &Expr, schema: &DFSchemaRef, extension_info: &mut (Vec, HashMap)) -> Result { +#[allow(deprecated)] +pub fn to_substrait_agg_measure( + expr: &Expr, + schema: &DFSchemaRef, + extension_info: &mut ( + Vec, + HashMap, + ), +) -> Result { match expr { - Expr::AggregateFunction { fun, args, distinct, filter } => { + Expr::AggregateFunction(expr::AggregateFunction { fun, args, distinct, filter }) => { let mut arguments: Vec = vec![]; for arg in args { arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex(arg, schema, extension_info)?)) }); @@ -334,7 +358,7 @@ pub fn to_substrait_agg_measure(expr: &Expr, schema: &DFSchemaRef, extension_inf Ok(Measure { measure: Some(AggregateFunction { function_reference: function_anchor, - arguments: arguments, + arguments, sorts: vec![], output_type: None, invocation: match distinct { @@ -357,7 +381,13 @@ pub fn to_substrait_agg_measure(expr: &Expr, schema: &DFSchemaRef, extension_inf } } -fn _register_function(function_name: String, extension_info: &mut (Vec, HashMap)) -> u32 { +fn _register_function( + function_name: String, + extension_info: &mut ( + Vec, + HashMap, + ), +) -> u32 { let (function_extensions, function_set) = extension_info; let function_name = function_name.to_lowercase(); // To prevent ambiguous references between ScalarFunctions and AggregateFunctions, @@ -368,7 +398,7 @@ fn _register_function(function_name: String, extension_info: &mut (Vec { // Function has been registered *function_anchor - }, + } None => { // Function has NOT been registered let function_anchor = function_set.len() as u32; @@ -376,7 +406,7 @@ fn _register_function(function_name: String, extension_info: &mut (Vec, HashMap)) -> Expression { +#[allow(deprecated)] +pub fn make_binary_op_scalar_func( + lhs: &Expression, + rhs: &Expression, + op: Operator, + extension_info: &mut ( + Vec, + HashMap, + ), +) -> Expression { let function_name = operator_to_name(op).to_string().to_lowercase(); let function_anchor = _register_function(function_name, extension_info); Expression { @@ -414,45 +452,92 @@ pub fn make_binary_op_scalar_func(lhs: &Expression, rhs: &Expression, op: Operat } /// Convert DataFusion Expr to Substrait Rex -pub fn to_substrait_rex(expr: &Expr, schema: &DFSchemaRef, extension_info: &mut (Vec, HashMap)) -> Result { +pub fn to_substrait_rex( + expr: &Expr, + schema: &DFSchemaRef, + extension_info: &mut ( + Vec, + HashMap, + ), +) -> Result { match expr { - Expr::Between { expr, negated, low, high } => { + Expr::Between(Between { + expr, + negated, + low, + high, + }) => { if *negated { // `expr NOT BETWEEN low AND high` can be translated into (expr < low OR high < expr) let substrait_expr = to_substrait_rex(expr, schema, extension_info)?; let substrait_low = to_substrait_rex(low, schema, extension_info)?; let substrait_high = to_substrait_rex(high, schema, extension_info)?; - let l_expr = make_binary_op_scalar_func(&substrait_expr, &substrait_low, Operator::Lt, extension_info); - let r_expr = make_binary_op_scalar_func(&substrait_high, &substrait_expr, Operator::Lt, extension_info); + let l_expr = make_binary_op_scalar_func( + &substrait_expr, + &substrait_low, + Operator::Lt, + extension_info, + ); + let r_expr = make_binary_op_scalar_func( + &substrait_high, + &substrait_expr, + Operator::Lt, + extension_info, + ); - Ok(make_binary_op_scalar_func(&l_expr, &r_expr, Operator::Or, extension_info)) + Ok(make_binary_op_scalar_func( + &l_expr, + &r_expr, + Operator::Or, + extension_info, + )) } else { // `expr BETWEEN low AND high` can be translated into (low <= expr AND expr <= high) let substrait_expr = to_substrait_rex(expr, schema, extension_info)?; let substrait_low = to_substrait_rex(low, schema, extension_info)?; let substrait_high = to_substrait_rex(high, schema, extension_info)?; - let l_expr = make_binary_op_scalar_func(&substrait_low, &substrait_expr, Operator::LtEq, extension_info); - let r_expr = make_binary_op_scalar_func(&substrait_expr, &substrait_high, Operator::LtEq, extension_info); + let l_expr = make_binary_op_scalar_func( + &substrait_low, + &substrait_expr, + Operator::LtEq, + extension_info, + ); + let r_expr = make_binary_op_scalar_func( + &substrait_expr, + &substrait_high, + Operator::LtEq, + extension_info, + ); - Ok(make_binary_op_scalar_func(&l_expr, &r_expr, Operator::And, extension_info)) + Ok(make_binary_op_scalar_func( + &l_expr, + &r_expr, + Operator::And, + extension_info, + )) } } Expr::Column(col) => { - let index = schema.index_of_column(&col)?; + let index = schema.index_of_column(col)?; substrait_field_ref(index) } - Expr::BinaryExpr { left, op, right } => { + Expr::BinaryExpr(BinaryExpr { left, op, right }) => { let l = to_substrait_rex(left, schema, extension_info)?; let r = to_substrait_rex(right, schema, extension_info)?; Ok(make_binary_op_scalar_func(&l, &r, *op, extension_info)) } - Expr::Case { expr, when_then_expr, else_expr } => { + Expr::Case(Case { + expr, + when_then_expr, + else_expr, + }) => { let mut ifs: Vec = vec![]; // Parse base - if let Some(e) = expr { // Base expression exists + if let Some(e) = expr { + // Base expression exists ifs.push(IfClause { r#if: Some(to_substrait_rex(e, schema, extension_info)?), then: None, @@ -471,12 +556,9 @@ pub fn to_substrait_rex(expr: &Expr, schema: &DFSchemaRef, extension_info: &mut Some(e) => Some(Box::new(to_substrait_rex(e, schema, extension_info)?)), None => None, }; - + Ok(Expression { - rex_type: Some(RexType::IfThen(Box::new(IfThen { - ifs: ifs, - r#else: r#else - }))), + rex_type: Some(RexType::IfThen(Box::new(IfThen { ifs, r#else }))), }) } Expr::Literal(value) => { @@ -508,9 +590,7 @@ pub fn to_substrait_rex(expr: &Expr, schema: &DFSchemaRef, extension_info: &mut })), }) } - Expr::Alias(expr, _alias) => { - to_substrait_rex(expr, schema, extension_info) - } + Expr::Alias(expr, _alias) => to_substrait_rex(expr, schema, extension_info), _ => Err(DataFusionError::NotImplemented(format!( "Unsupported expression: {:?}", expr @@ -518,9 +598,20 @@ pub fn to_substrait_rex(expr: &Expr, schema: &DFSchemaRef, extension_info: &mut } } -fn substrait_sort_field(expr: &Expr, schema: &DFSchemaRef, extension_info: &mut (Vec, HashMap)) -> Result { +fn substrait_sort_field( + expr: &Expr, + schema: &DFSchemaRef, + extension_info: &mut ( + Vec, + HashMap, + ), +) -> Result { match expr { - Expr::Sort { expr, asc, nulls_first } => { + Expr::Sort(Sort { + expr, + asc, + nulls_first, + }) => { let e = to_substrait_rex(expr, schema, extension_info)?; let d = match (asc, nulls_first) { (true, true) => SortDirection::AscNullsFirst, @@ -532,7 +623,7 @@ fn substrait_sort_field(expr: &Expr, schema: &DFSchemaRef, extension_info: &mut expr: Some(e), sort_kind: Some(SortKind::Direction(d as i32)), }) - }, + } _ => Err(DataFusionError::NotImplemented(format!( "Expecting sort expression but got {:?}", expr diff --git a/datafusion/substrait/src/serializer.rs b/datafusion/substrait/src/serializer.rs index 7f52077f1be9..d71a30d76a60 100644 --- a/datafusion/substrait/src/serializer.rs +++ b/datafusion/substrait/src/serializer.rs @@ -24,19 +24,16 @@ use prost::Message; use substrait::protobuf::Plan; use std::fs::OpenOptions; -use std::io::{Write, Read}; +use std::io::{Read, Write}; pub async fn serialize(sql: &str, ctx: &SessionContext, path: &str) -> Result<()> { let df = ctx.sql(sql).await?; - let plan = df.to_logical_plan()?; + let plan = df.into_optimized_plan()?; let proto = producer::to_substrait_plan(&plan)?; let mut protobuf_out = Vec::::new(); proto.encode(&mut protobuf_out).unwrap(); - let mut file = OpenOptions::new() - .create(true) - .write(true) - .open(path)?; + let mut file = OpenOptions::new().create(true).write(true).open(path)?; file.write_all(&protobuf_out)?; Ok(()) } @@ -44,14 +41,10 @@ pub async fn serialize(sql: &str, ctx: &SessionContext, path: &str) -> Result<() pub async fn deserialize(path: &str) -> Result> { let mut protobuf_in = Vec::::new(); - let mut file = OpenOptions::new() - .read(true) - .open(path)?; + let mut file = OpenOptions::new().read(true).open(path)?; file.read_to_end(&mut protobuf_in)?; let proto = Message::decode(&*protobuf_in).unwrap(); Ok(Box::new(proto)) } - - diff --git a/datafusion/substrait/tests/roundtrip.rs b/datafusion/substrait/tests/roundtrip.rs index 21a3a5f291a3..5fde79d4cc76 100644 --- a/datafusion/substrait/tests/roundtrip.rs +++ b/datafusion/substrait/tests/roundtrip.rs @@ -45,7 +45,8 @@ mod tests { async fn select_with_reused_functions() -> Result<()> { let sql = "SELECT * FROM data WHERE a > 1 AND a < 10 AND b > 0"; roundtrip(sql).await?; - let (mut function_names, mut function_anchors) = function_extension_info(sql).await?; + let (mut function_names, mut function_anchors) = + function_extension_info(sql).await?; function_names.sort(); function_anchors.sort(); @@ -82,7 +83,10 @@ mod tests { #[tokio::test] async fn aggregate_distinct_with_having() -> Result<()> { - roundtrip("SELECT a, count(distinct b) FROM data GROUP BY a, c HAVING count(b) > 100").await + roundtrip( + "SELECT a, count(distinct b) FROM data GROUP BY a, c HAVING count(b) > 100", + ) + .await } #[tokio::test] @@ -95,7 +99,8 @@ mod tests { test_alias( "SELECT * FROM (SELECT distinct a FROM data)", // `SELECT *` is used to add `projection` at the root "SELECT a FROM data GROUP BY a", - ).await + ) + .await } #[tokio::test] @@ -103,15 +108,13 @@ mod tests { test_alias( "SELECT * FROM (SELECT distinct a, b FROM data)", // `SELECT *` is used to add `projection` at the root "SELECT a, b FROM data GROUP BY a, b", - ).await + ) + .await } #[tokio::test] async fn simple_alias() -> Result<()> { - test_alias( - "SELECT d1.a, d1.b FROM data d1", - "SELECT a, b FROM data", - ).await + test_alias("SELECT d1.a, d1.b FROM data d1", "SELECT a, b FROM data").await } #[tokio::test] @@ -127,7 +130,7 @@ mod tests { async fn between_integers() -> Result<()> { test_alias( "SELECT * FROM data WHERE a BETWEEN 2 AND 6", - "SELECT * FROM data WHERE a >= 2 AND a <= 6" + "SELECT * FROM data WHERE a >= 2 AND a <= 6", ) .await } @@ -136,23 +139,29 @@ mod tests { async fn not_between_integers() -> Result<()> { test_alias( "SELECT * FROM data WHERE a NOT BETWEEN 2 AND 6", - "SELECT * FROM data WHERE a < 2 OR a > 6" + "SELECT * FROM data WHERE a < 2 OR a > 6", ) .await } #[tokio::test] async fn case_without_base_expression() -> Result<()> { - roundtrip("SELECT (CASE WHEN a >= 0 THEN 'positive' ELSE 'negative' END) FROM data").await + roundtrip( + "SELECT (CASE WHEN a >= 0 THEN 'positive' ELSE 'negative' END) FROM data", + ) + .await } #[tokio::test] async fn case_with_base_expression() -> Result<()> { - roundtrip("SELECT (CASE a + roundtrip( + "SELECT (CASE a WHEN 0 THEN 'zero' WHEN 1 THEN 'one' ELSE 'other' - END) FROM data").await + END) FROM data", + ) + .await } #[tokio::test] @@ -175,10 +184,10 @@ mod tests { async fn assert_expected_plan(sql: &str, expected_plan_str: &str) -> Result<()> { let mut ctx = create_context().await?; let df = ctx.sql(sql).await?; - let plan = df.to_logical_plan()?; + let plan = df.into_optimized_plan()?; let proto = to_substrait_plan(&plan)?; let df = from_substrait_plan(&mut ctx, &proto).await?; - let plan2 = df.to_logical_plan()?; + let plan2 = df.into_optimized_plan()?; let plan2str = format!("{:?}", plan2); assert_eq!(expected_plan_str, &plan2str); Ok(()) @@ -187,11 +196,11 @@ mod tests { async fn roundtrip_fill_na(sql: &str) -> Result<()> { let mut ctx = create_context().await?; let df = ctx.sql(sql).await?; - let plan1 = df.to_logical_plan()?; + let plan1 = df.into_optimized_plan()?; let proto = to_substrait_plan(&plan1)?; let df = from_substrait_plan(&mut ctx, &proto).await?; - let plan2 = df.to_logical_plan()?; + let plan2 = df.into_optimized_plan()?; // Format plan string and replace all None's with 0 let plan1str = format!("{:?}", plan1).replace("None", "0"); @@ -208,12 +217,16 @@ mod tests { let mut ctx = create_context().await?; let df_a = ctx.sql(sql_with_alias).await?; - let proto_a = to_substrait_plan(&df_a.to_logical_plan()?)?; - let plan_with_alias = from_substrait_plan(&mut ctx, &proto_a).await?.to_logical_plan()?; + let proto_a = to_substrait_plan(&df_a.into_optimized_plan()?)?; + let plan_with_alias = from_substrait_plan(&mut ctx, &proto_a) + .await? + .into_optimized_plan()?; let df = ctx.sql(sql_no_alias).await?; - let proto = to_substrait_plan(&df.to_logical_plan()?)?; - let plan = from_substrait_plan(&mut ctx, &proto).await?.to_logical_plan()?; + let proto = to_substrait_plan(&df.into_optimized_plan()?)?; + let plan = from_substrait_plan(&mut ctx, &proto) + .await? + .into_optimized_plan()?; println!("{:#?}", plan_with_alias); println!("{:#?}", plan); @@ -227,11 +240,11 @@ mod tests { async fn roundtrip(sql: &str) -> Result<()> { let mut ctx = create_context().await?; let df = ctx.sql(sql).await?; - let plan = df.to_logical_plan()?; + let plan = df.into_optimized_plan()?; let proto = to_substrait_plan(&plan)?; let df = from_substrait_plan(&mut ctx, &proto).await?; - let plan2 = df.to_logical_plan()?; + let plan2 = df.into_optimized_plan()?; println!("{:#?}", plan); println!("{:#?}", plan2); @@ -242,23 +255,26 @@ mod tests { Ok(()) } - async fn function_extension_info(sql: &str) -> Result<(Vec, Vec)> { + async fn function_extension_info(sql: &str) -> Result<(Vec, Vec)> { let ctx = create_context().await?; let df = ctx.sql(sql).await?; - let plan = df.to_logical_plan()?; + let plan = df.into_optimized_plan()?; let proto = to_substrait_plan(&plan)?; let mut function_names: Vec = vec![]; let mut function_anchors: Vec = vec![]; for e in &proto.extensions { - let (function_anchor, function_name) = match e.mapping_type.as_ref().unwrap() { - MappingType::ExtensionFunction(ext_f) => (ext_f.function_anchor, &ext_f.name), - _ => unreachable!("Producer does not generate a non-function extension") + let (function_anchor, function_name) = match e.mapping_type.as_ref().unwrap() + { + MappingType::ExtensionFunction(ext_f) => { + (ext_f.function_anchor, &ext_f.name) + } + _ => unreachable!("Producer does not generate a non-function extension"), }; function_names.push(function_name.to_string()); function_anchors.push(function_anchor); } - + Ok((function_names, function_anchors)) } diff --git a/datafusion/substrait/tests/serialize.rs b/datafusion/substrait/tests/serialize.rs index 505c4f5f4ec4..59b2899ede39 100644 --- a/datafusion/substrait/tests/serialize.rs +++ b/datafusion/substrait/tests/serialize.rs @@ -33,7 +33,7 @@ mod tests { let sql = "SELECT a, b FROM data"; // Test reference let df_ref = ctx.sql(sql).await?; - let plan_ref = df_ref.to_logical_plan()?; + let plan_ref = df_ref.into_optimized_plan()?; // Test // Write substrait plan to file serializer::serialize(sql, &ctx, &path).await?; @@ -41,7 +41,7 @@ mod tests { let proto = serializer::deserialize(path).await?; // Check plan equality let df = from_substrait_plan(&mut ctx, &proto).await?; - let plan = df.to_logical_plan()?; + let plan = df.into_optimized_plan()?; let plan_str_ref = format!("{:?}", plan_ref); let plan_str = format!("{:?}", plan); assert_eq!(plan_str_ref, plan_str); @@ -59,4 +59,4 @@ mod tests { .await?; Ok(ctx) } -} \ No newline at end of file +}