From 4c898b45720efed56f15f8030e8ca2c1e5f6ec1a Mon Sep 17 00:00:00 2001 From: Victor Barua Date: Sun, 5 Jan 2025 10:26:09 -0800 Subject: [PATCH] feat(substrait): modular substrait producer (#13931) * feat(substrait): modular substrait producer * refactor(substrait): simplify col_ref_offset handling in producer * refactor(substrait): remove column offset tracking from producer * docs(substrait): document SubstraitProducer * refactor: minor cleanup * feature: remove unused SubstraitPlanningState BREAKING CHANGE: SubstraitPlanningState is no longer available * refactor: cargo fmt * refactor(substrait): consume_ -> handle_ * refactor(substrait): expand match blocks * refactor: DefaultSubstraitProducer only needs serializer_registry * refactor: remove unnecessary warning suppression * fix(substrait): route expr conversion through handle_expr * cargo fmt --- .../substrait/src/logical_plan/consumer.rs | 3 + datafusion/substrait/src/logical_plan/mod.rs | 1 - .../substrait/src/logical_plan/producer.rs | 2260 ++++++++++------- .../substrait/src/logical_plan/state.rs | 63 - .../tests/cases/roundtrip_logical_plan.rs | 15 + 5 files changed, 1295 insertions(+), 1047 deletions(-) delete mode 100644 datafusion/substrait/src/logical_plan/state.rs diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index 0ee87afe3286..9623f12c88dd 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -114,6 +114,9 @@ use substrait::proto::{ /// This trait is used to consume Substrait plans, converting them into DataFusion Logical Plans. /// It can be implemented by users to allow for custom handling of relations, expressions, etc. /// +/// Combined with the [crate::logical_plan::producer::SubstraitProducer] this allows for fully +/// customizable Substrait serde. +/// /// # Example Usage /// /// ``` diff --git a/datafusion/substrait/src/logical_plan/mod.rs b/datafusion/substrait/src/logical_plan/mod.rs index 9e2fa9fa49de..6f8b8e493f52 100644 --- a/datafusion/substrait/src/logical_plan/mod.rs +++ b/datafusion/substrait/src/logical_plan/mod.rs @@ -17,4 +17,3 @@ pub mod consumer; pub mod producer; -pub mod state; diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index b73d246e1989..e501ddf5c698 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -22,7 +22,11 @@ use std::sync::Arc; use substrait::proto::expression_reference::ExprType; use datafusion::arrow::datatypes::{Field, IntervalUnit}; -use datafusion::logical_expr::{Distinct, Like, Partitioning, TryCast, WindowFrameUnits}; +use datafusion::logical_expr::{ + Aggregate, Distinct, EmptyRelation, Extension, Filter, Join, Like, Limit, + Partitioning, Projection, Repartition, Sort, SortExpr, SubqueryAlias, TableScan, + TryCast, Union, Values, Window, WindowFrameUnits, +}; use datafusion::{ arrow::datatypes::{DataType, TimeUnit}, error::{DataFusionError, Result}, @@ -43,11 +47,12 @@ use datafusion::arrow::array::{Array, GenericListArray, OffsetSizeTrait}; use datafusion::arrow::temporal_conversions::NANOSECONDS; use datafusion::common::{ exec_err, internal_err, not_impl_err, plan_err, substrait_datafusion_err, - substrait_err, DFSchema, DFSchemaRef, ToDFSchema, + substrait_err, Column, DFSchema, DFSchemaRef, ToDFSchema, }; -#[allow(unused_imports)] +use datafusion::execution::registry::SerializerRegistry; +use datafusion::execution::SessionState; use datafusion::logical_expr::expr::{ - Alias, BinaryExpr, Case, Cast, GroupingSet, InList, InSubquery, Sort, WindowFunction, + Alias, BinaryExpr, Case, Cast, GroupingSet, InList, InSubquery, WindowFunction, }; use datafusion::logical_expr::{expr, Between, JoinConstraint, LogicalPlan, Operator}; use datafusion::prelude::Expr; @@ -63,6 +68,7 @@ use substrait::proto::expression::literal::{ }; use substrait::proto::expression::subquery::InPredicate; use substrait::proto::expression::window_function::BoundsType; +use substrait::proto::expression::ScalarFunction; use substrait::proto::read_rel::VirtualTable; use substrait::proto::rel_common::EmitKind; use substrait::proto::rel_common::EmitKind::Emit; @@ -84,8 +90,7 @@ use substrait::{ window_function::bound::Kind as BoundKind, window_function::Bound, FieldReference, IfThen, Literal, MaskExpression, ReferenceSegment, RexType, - ScalarFunction, SingularOrList, Subquery, - WindowFunction as SubstraitWindowFunction, + SingularOrList, WindowFunction as SubstraitWindowFunction, }, function_argument::ArgType, join_rel, plan_rel, r#type, @@ -101,14 +106,329 @@ use substrait::{ version, }; -use super::state::SubstraitPlanningState; +/// This trait is used to produce Substrait plans, converting them from DataFusion Logical Plans. +/// It can be implemented by users to allow for custom handling of relations, expressions, etc. +/// +/// Combined with the [crate::logical_plan::consumer::SubstraitConsumer] this allows for fully +/// customizable Substrait serde. +/// +/// # Example Usage +/// +/// ``` +/// # use std::sync::Arc; +/// # use substrait::proto::{Expression, Rel}; +/// # use substrait::proto::rel::RelType; +/// # use datafusion::common::DFSchemaRef; +/// # use datafusion::error::Result; +/// # use datafusion::execution::SessionState; +/// # use datafusion::logical_expr::{Between, Extension, Projection}; +/// # use datafusion_substrait::extensions::Extensions; +/// # use datafusion_substrait::logical_plan::producer::{from_projection, SubstraitProducer}; +/// +/// struct CustomSubstraitProducer { +/// extensions: Extensions, +/// state: Arc, +/// } +/// +/// impl SubstraitProducer for CustomSubstraitProducer { +/// +/// fn register_function(&mut self, signature: String) -> u32 { +/// self.extensions.register_function(signature) +/// } +/// +/// fn get_extensions(self) -> Extensions { +/// self.extensions +/// } +/// +/// // You can set additional metadata on the Rels you produce +/// fn handle_projection(&mut self, plan: &Projection) -> Result> { +/// let mut rel = from_projection(self, plan)?; +/// match rel.rel_type { +/// Some(RelType::Project(mut project)) => { +/// let mut project = project.clone(); +/// // set common metadata or advanced extension +/// project.common = None; +/// project.advanced_extension = None; +/// Ok(Box::new(Rel { +/// rel_type: Some(RelType::Project(project)), +/// })) +/// } +/// rel_type => Ok(Box::new(Rel { rel_type })), +/// } +/// } +/// +/// // You can tweak how you convert expressions for your target system +/// fn handle_between(&mut self, between: &Between, schema: &DFSchemaRef) -> Result { +/// // add your own encoding for Between +/// todo!() +/// } +/// +/// // You can fully control how you convert UserDefinedLogicalNodes into Substrait +/// fn handle_extension(&mut self, _plan: &Extension) -> Result> { +/// // implement your own serializer into Substrait +/// todo!() +/// } +/// } +/// ``` +pub trait SubstraitProducer: Send + Sync + Sized { + /// Within a Substrait plan, functions are referenced using function anchors that are stored at + /// the top level of the [Plan] within + /// [ExtensionFunction](substrait::proto::extensions::simple_extension_declaration::ExtensionFunction) + /// messages. + /// + /// When given a function signature, this method should return the existing anchor for it if + /// there is one. Otherwise, it should generate a new anchor. + fn register_function(&mut self, signature: String) -> u32; + + /// Consume the producer to generate the [Extensions] for the Substrait plan based on the + /// functions that have been registered + fn get_extensions(self) -> Extensions; + + // Logical Plan Methods + // There is one method per LogicalPlan to allow for easy overriding of producer behaviour. + // These methods have default implementations calling the common handler code, to allow for users + // to re-use common handling logic. + + fn handle_plan(&mut self, plan: &LogicalPlan) -> Result> { + to_substrait_rel(self, plan) + } + + fn handle_projection(&mut self, plan: &Projection) -> Result> { + from_projection(self, plan) + } + + fn handle_filter(&mut self, plan: &Filter) -> Result> { + from_filter(self, plan) + } + + fn handle_window(&mut self, plan: &Window) -> Result> { + from_window(self, plan) + } + + fn handle_aggregate(&mut self, plan: &Aggregate) -> Result> { + from_aggregate(self, plan) + } + + fn handle_sort(&mut self, plan: &Sort) -> Result> { + from_sort(self, plan) + } + + fn handle_join(&mut self, plan: &Join) -> Result> { + from_join(self, plan) + } + + fn handle_repartition(&mut self, plan: &Repartition) -> Result> { + from_repartition(self, plan) + } + + fn handle_union(&mut self, plan: &Union) -> Result> { + from_union(self, plan) + } + + fn handle_table_scan(&mut self, plan: &TableScan) -> Result> { + from_table_scan(self, plan) + } + + fn handle_empty_relation(&mut self, plan: &EmptyRelation) -> Result> { + from_empty_relation(plan) + } + + fn handle_subquery_alias(&mut self, plan: &SubqueryAlias) -> Result> { + from_subquery_alias(self, plan) + } + + fn handle_limit(&mut self, plan: &Limit) -> Result> { + from_limit(self, plan) + } + + fn handle_values(&mut self, plan: &Values) -> Result> { + from_values(self, plan) + } + + fn handle_distinct(&mut self, plan: &Distinct) -> Result> { + from_distinct(self, plan) + } + + fn handle_extension(&mut self, _plan: &Extension) -> Result> { + substrait_err!("Specify handling for LogicalPlan::Extension by implementing the SubstraitProducer trait") + } + + // Expression Methods + // There is one method per DataFusion Expr to allow for easy overriding of producer behaviour + // These methods have default implementations calling the common handler code, to allow for users + // to re-use common handling logic. + + fn handle_expr(&mut self, expr: &Expr, schema: &DFSchemaRef) -> Result { + to_substrait_rex(self, expr, schema) + } + + fn handle_alias( + &mut self, + alias: &Alias, + schema: &DFSchemaRef, + ) -> Result { + from_alias(self, alias, schema) + } + + fn handle_column( + &mut self, + column: &Column, + schema: &DFSchemaRef, + ) -> Result { + from_column(column, schema) + } + + fn handle_literal(&mut self, value: &ScalarValue) -> Result { + from_literal(self, value) + } + + fn handle_binary_expr( + &mut self, + expr: &BinaryExpr, + schema: &DFSchemaRef, + ) -> Result { + from_binary_expr(self, expr, schema) + } + + fn handle_like(&mut self, like: &Like, schema: &DFSchemaRef) -> Result { + from_like(self, like, schema) + } + + /// For handling Not, IsNotNull, IsNull, IsTrue, IsFalse, IsUnknown, IsNotTrue, IsNotFalse, IsNotUnknown, Negative + fn handle_unary_expr( + &mut self, + expr: &Expr, + schema: &DFSchemaRef, + ) -> Result { + from_unary_expr(self, expr, schema) + } + + fn handle_between( + &mut self, + between: &Between, + schema: &DFSchemaRef, + ) -> Result { + from_between(self, between, schema) + } + + fn handle_case(&mut self, case: &Case, schema: &DFSchemaRef) -> Result { + from_case(self, case, schema) + } + + fn handle_cast(&mut self, cast: &Cast, schema: &DFSchemaRef) -> Result { + from_cast(self, cast, schema) + } + + fn handle_try_cast( + &mut self, + cast: &TryCast, + schema: &DFSchemaRef, + ) -> Result { + from_try_cast(self, cast, schema) + } + + fn handle_scalar_function( + &mut self, + scalar_fn: &expr::ScalarFunction, + schema: &DFSchemaRef, + ) -> Result { + from_scalar_function(self, scalar_fn, schema) + } + + fn handle_aggregate_function( + &mut self, + agg_fn: &expr::AggregateFunction, + schema: &DFSchemaRef, + ) -> Result { + from_aggregate_function(self, agg_fn, schema) + } + + fn handle_window_function( + &mut self, + window_fn: &WindowFunction, + schema: &DFSchemaRef, + ) -> Result { + from_window_function(self, window_fn, schema) + } + + fn handle_in_list( + &mut self, + in_list: &InList, + schema: &DFSchemaRef, + ) -> Result { + from_in_list(self, in_list, schema) + } + + fn handle_in_subquery( + &mut self, + in_subquery: &InSubquery, + schema: &DFSchemaRef, + ) -> Result { + from_in_subquery(self, in_subquery, schema) + } +} + +struct DefaultSubstraitProducer<'a> { + extensions: Extensions, + serializer_registry: &'a dyn SerializerRegistry, +} + +impl<'a> DefaultSubstraitProducer<'a> { + pub fn new(state: &'a SessionState) -> Self { + DefaultSubstraitProducer { + extensions: Extensions::default(), + serializer_registry: state.serializer_registry().as_ref(), + } + } +} + +impl SubstraitProducer for DefaultSubstraitProducer<'_> { + fn register_function(&mut self, fn_name: String) -> u32 { + self.extensions.register_function(fn_name) + } + + fn get_extensions(self) -> Extensions { + self.extensions + } + + fn handle_extension(&mut self, plan: &Extension) -> Result> { + let extension_bytes = self + .serializer_registry + .serialize_logical_plan(plan.node.as_ref())?; + let detail = ProtoAny { + type_url: plan.node.name().to_string(), + value: extension_bytes.into(), + }; + let mut inputs_rel = plan + .node + .inputs() + .into_iter() + .map(|plan| self.handle_plan(plan)) + .collect::>>()?; + let rel_type = match inputs_rel.len() { + 0 => RelType::ExtensionLeaf(ExtensionLeafRel { + common: None, + detail: Some(detail), + }), + 1 => RelType::ExtensionSingle(Box::new(ExtensionSingleRel { + common: None, + detail: Some(detail), + input: Some(inputs_rel.pop().unwrap()), + })), + _ => RelType::ExtensionMulti(ExtensionMultiRel { + common: None, + detail: Some(detail), + inputs: inputs_rel.into_iter().map(|r| *r).collect(), + }), + }; + Ok(Box::new(Rel { + rel_type: Some(rel_type), + })) + } +} /// Convert DataFusion LogicalPlan to Substrait Plan -pub fn to_substrait_plan( - plan: &LogicalPlan, - state: &dyn SubstraitPlanningState, -) -> Result> { - let mut extensions = Extensions::default(); +pub fn to_substrait_plan(plan: &LogicalPlan, state: &SessionState) -> Result> { // Parse relation nodes // Generate PlanRel(s) // Note: Only 1 relation tree is currently supported @@ -117,14 +437,16 @@ pub fn to_substrait_plan( let plan = Arc::new(ExpandWildcardRule::new()) .analyze(plan.clone(), &ConfigOptions::default())?; + let mut producer: DefaultSubstraitProducer = DefaultSubstraitProducer::new(state); let plan_rels = vec![PlanRel { rel_type: Some(plan_rel::RelType::Root(RelRoot { - input: Some(*to_substrait_rel(&plan, state, &mut extensions)?), + input: Some(*producer.handle_plan(&plan)?), names: to_substrait_named_struct(plan.schema())?.names, })), }]; // Return parsed plan + let extensions = producer.get_extensions(); Ok(Box::new(Plan { version: Some(version::version_with_producer("datafusion")), extension_uris: vec![], @@ -150,20 +472,13 @@ pub fn to_substrait_plan( pub fn to_substrait_extended_expr( exprs: &[(&Expr, &Field)], schema: &DFSchemaRef, - state: &dyn SubstraitPlanningState, + state: &SessionState, ) -> Result> { - let mut extensions = Extensions::default(); - + let mut producer = DefaultSubstraitProducer::new(state); let substrait_exprs = exprs .iter() .map(|(expr, field)| { - let substrait_expr = to_substrait_rex( - state, - expr, - schema, - /*col_ref_offset=*/ 0, - &mut extensions, - )?; + let substrait_expr = producer.handle_expr(expr, schema)?; let mut output_names = Vec::new(); flatten_names(field, false, &mut output_names)?; Ok(ExpressionReference { @@ -174,6 +489,7 @@ pub fn to_substrait_extended_expr( .collect::>>()?; let substrait_schema = to_substrait_named_struct(schema)?; + let extensions = producer.get_extensions(); Ok(Box::new(ExtendedExpression { advanced_extensions: None, expected_type_urls: vec![], @@ -185,257 +501,303 @@ pub fn to_substrait_extended_expr( })) } -/// Convert DataFusion LogicalPlan to Substrait Rel -#[allow(deprecated)] pub fn to_substrait_rel( + producer: &mut impl SubstraitProducer, plan: &LogicalPlan, - state: &dyn SubstraitPlanningState, - extensions: &mut Extensions, ) -> Result> { match plan { - LogicalPlan::TableScan(scan) => { - let projection = scan.projection.as_ref().map(|p| { - p.iter() - .map(|i| StructItem { - field: *i as i32, - child: None, - }) - .collect() - }); + LogicalPlan::Projection(plan) => producer.handle_projection(plan), + LogicalPlan::Filter(plan) => producer.handle_filter(plan), + LogicalPlan::Window(plan) => producer.handle_window(plan), + LogicalPlan::Aggregate(plan) => producer.handle_aggregate(plan), + LogicalPlan::Sort(plan) => producer.handle_sort(plan), + LogicalPlan::Join(plan) => producer.handle_join(plan), + LogicalPlan::Repartition(plan) => producer.handle_repartition(plan), + LogicalPlan::Union(plan) => producer.handle_union(plan), + LogicalPlan::TableScan(plan) => producer.handle_table_scan(plan), + LogicalPlan::EmptyRelation(plan) => producer.handle_empty_relation(plan), + LogicalPlan::Subquery(plan) => not_impl_err!("Unsupported plan type: {plan:?}")?, + LogicalPlan::SubqueryAlias(plan) => producer.handle_subquery_alias(plan), + LogicalPlan::Limit(plan) => producer.handle_limit(plan), + LogicalPlan::Statement(plan) => not_impl_err!("Unsupported plan type: {plan:?}")?, + LogicalPlan::Values(plan) => producer.handle_values(plan), + LogicalPlan::Explain(plan) => not_impl_err!("Unsupported plan type: {plan:?}")?, + LogicalPlan::Analyze(plan) => not_impl_err!("Unsupported plan type: {plan:?}")?, + LogicalPlan::Extension(plan) => producer.handle_extension(plan), + LogicalPlan::Distinct(plan) => producer.handle_distinct(plan), + LogicalPlan::Dml(plan) => not_impl_err!("Unsupported plan type: {plan:?}")?, + LogicalPlan::Ddl(plan) => not_impl_err!("Unsupported plan type: {plan:?}")?, + LogicalPlan::Copy(plan) => not_impl_err!("Unsupported plan type: {plan:?}")?, + LogicalPlan::DescribeTable(plan) => { + not_impl_err!("Unsupported plan type: {plan:?}")? + } + LogicalPlan::Unnest(plan) => not_impl_err!("Unsupported plan type: {plan:?}")?, + LogicalPlan::RecursiveQuery(plan) => { + not_impl_err!("Unsupported plan type: {plan:?}")? + } + } +} - let projection = projection.map(|struct_items| MaskExpression { - select: Some(StructSelect { struct_items }), - maintain_singular_struct: false, - }); +pub fn from_table_scan( + _producer: &mut impl SubstraitProducer, + scan: &TableScan, +) -> Result> { + let projection = scan.projection.as_ref().map(|p| { + p.iter() + .map(|i| StructItem { + field: *i as i32, + child: None, + }) + .collect() + }); + + let projection = projection.map(|struct_items| MaskExpression { + select: Some(StructSelect { struct_items }), + maintain_singular_struct: false, + }); + + let table_schema = scan.source.schema().to_dfschema_ref()?; + let base_schema = to_substrait_named_struct(&table_schema)?; + + Ok(Box::new(Rel { + rel_type: Some(RelType::Read(Box::new(ReadRel { + common: None, + base_schema: Some(base_schema), + filter: None, + best_effort_filter: None, + projection, + advanced_extension: None, + read_type: Some(ReadType::NamedTable(NamedTable { + names: scan.table_name.to_vec(), + advanced_extension: None, + })), + }))), + })) +} - let table_schema = scan.source.schema().to_dfschema_ref()?; - let base_schema = to_substrait_named_struct(&table_schema)?; +pub fn from_empty_relation(e: &EmptyRelation) -> Result> { + if e.produce_one_row { + return not_impl_err!("Producing a row from empty relation is unsupported"); + } + #[allow(deprecated)] + Ok(Box::new(Rel { + rel_type: Some(RelType::Read(Box::new(ReadRel { + common: None, + base_schema: Some(to_substrait_named_struct(&e.schema)?), + filter: None, + best_effort_filter: None, + projection: None, + advanced_extension: None, + read_type: Some(ReadType::VirtualTable(VirtualTable { + values: vec![], + expressions: vec![], + })), + }))), + })) +} - Ok(Box::new(Rel { - rel_type: Some(RelType::Read(Box::new(ReadRel { - common: None, - base_schema: Some(base_schema), - filter: None, - best_effort_filter: None, - projection, - advanced_extension: None, - read_type: Some(ReadType::NamedTable(NamedTable { - names: scan.table_name.to_vec(), - advanced_extension: None, - })), - }))), - })) - } - LogicalPlan::EmptyRelation(e) => { - if e.produce_one_row { - return not_impl_err!( - "Producing a row from empty relation is unsupported" - ); - } - Ok(Box::new(Rel { - rel_type: Some(RelType::Read(Box::new(ReadRel { - common: None, - base_schema: Some(to_substrait_named_struct(&e.schema)?), - filter: None, - best_effort_filter: None, - projection: None, - advanced_extension: None, - read_type: Some(ReadType::VirtualTable(VirtualTable { - values: vec![], - expressions: vec![], - })), - }))), - })) - } - LogicalPlan::Values(v) => { - let values = v - .values +pub fn from_values( + producer: &mut impl SubstraitProducer, + v: &Values, +) -> Result> { + let values = v + .values + .iter() + .map(|row| { + let fields = row .iter() - .map(|row| { - let fields = row - .iter() - .map(|v| match v { - Expr::Literal(sv) => to_substrait_literal(sv, extensions), - Expr::Alias(alias) => match alias.expr.as_ref() { - // The schema gives us the names, so we can skip aliases - Expr::Literal(sv) => to_substrait_literal(sv, extensions), - _ => Err(substrait_datafusion_err!( + .map(|v| match v { + Expr::Literal(sv) => to_substrait_literal(producer, sv), + Expr::Alias(alias) => match alias.expr.as_ref() { + // The schema gives us the names, so we can skip aliases + Expr::Literal(sv) => to_substrait_literal(producer, sv), + _ => Err(substrait_datafusion_err!( "Only literal types can be aliased in Virtual Tables, got: {}", alias.expr.variant_name() )), - }, - _ => Err(substrait_datafusion_err!( + }, + _ => Err(substrait_datafusion_err!( "Only literal types and aliases are supported in Virtual Tables, got: {}", v.variant_name() )), - }) - .collect::>()?; - Ok(Struct { fields }) }) .collect::>()?; - Ok(Box::new(Rel { - rel_type: Some(RelType::Read(Box::new(ReadRel { - common: None, - base_schema: Some(to_substrait_named_struct(&v.schema)?), - filter: None, - best_effort_filter: None, - projection: None, - advanced_extension: None, - read_type: Some(ReadType::VirtualTable(VirtualTable { - values, - expressions: vec![], - })), - }))), - })) - } - LogicalPlan::Projection(p) => { - let expressions = p - .expr - .iter() - .map(|e| to_substrait_rex(state, e, p.input.schema(), 0, extensions)) - .collect::>>()?; + Ok(Struct { fields }) + }) + .collect::>()?; + #[allow(deprecated)] + Ok(Box::new(Rel { + rel_type: Some(RelType::Read(Box::new(ReadRel { + common: None, + base_schema: Some(to_substrait_named_struct(&v.schema)?), + filter: None, + best_effort_filter: None, + projection: None, + advanced_extension: None, + read_type: Some(ReadType::VirtualTable(VirtualTable { + values, + expressions: vec![], + })), + }))), + })) +} - let emit_kind = create_project_remapping( - expressions.len(), - p.input.as_ref().schema().fields().len(), - ); - let common = RelCommon { - emit_kind: Some(emit_kind), - hint: None, - advanced_extension: None, - }; +pub fn from_projection( + producer: &mut impl SubstraitProducer, + p: &Projection, +) -> Result> { + let expressions = p + .expr + .iter() + .map(|e| producer.handle_expr(e, p.input.schema())) + .collect::>>()?; - Ok(Box::new(Rel { - rel_type: Some(RelType::Project(Box::new(ProjectRel { - common: Some(common), - input: Some(to_substrait_rel(p.input.as_ref(), state, extensions)?), - expressions, - advanced_extension: None, - }))), - })) - } - LogicalPlan::Filter(filter) => { - let input = to_substrait_rel(filter.input.as_ref(), state, extensions)?; - let filter_expr = to_substrait_rex( - state, - &filter.predicate, - filter.input.schema(), - 0, - extensions, - )?; - Ok(Box::new(Rel { - rel_type: Some(RelType::Filter(Box::new(FilterRel { - common: None, - input: Some(input), - condition: Some(Box::new(filter_expr)), - advanced_extension: None, - }))), - })) - } - LogicalPlan::Limit(limit) => { - let input = to_substrait_rel(limit.input.as_ref(), state, extensions)?; - let empty_schema = Arc::new(DFSchema::empty()); - let offset_mode = limit - .skip - .as_ref() - .map(|expr| { - to_substrait_rex(state, expr.as_ref(), &empty_schema, 0, extensions) - }) - .transpose()? - .map(Box::new) - .map(fetch_rel::OffsetMode::OffsetExpr); - let count_mode = limit - .fetch - .as_ref() - .map(|expr| { - to_substrait_rex(state, expr.as_ref(), &empty_schema, 0, extensions) - }) - .transpose()? - .map(Box::new) - .map(fetch_rel::CountMode::CountExpr); + let emit_kind = create_project_remapping( + expressions.len(), + p.input.as_ref().schema().fields().len(), + ); + let common = RelCommon { + emit_kind: Some(emit_kind), + hint: None, + advanced_extension: None, + }; + + Ok(Box::new(Rel { + rel_type: Some(RelType::Project(Box::new(ProjectRel { + common: Some(common), + input: Some(producer.handle_plan(p.input.as_ref())?), + expressions, + advanced_extension: None, + }))), + })) +} + +pub fn from_filter( + producer: &mut impl SubstraitProducer, + filter: &Filter, +) -> Result> { + let input = producer.handle_plan(filter.input.as_ref())?; + let filter_expr = producer.handle_expr(&filter.predicate, filter.input.schema())?; + Ok(Box::new(Rel { + rel_type: Some(RelType::Filter(Box::new(FilterRel { + common: None, + input: Some(input), + condition: Some(Box::new(filter_expr)), + advanced_extension: None, + }))), + })) +} + +pub fn from_limit( + producer: &mut impl SubstraitProducer, + limit: &Limit, +) -> Result> { + let input = producer.handle_plan(limit.input.as_ref())?; + let empty_schema = Arc::new(DFSchema::empty()); + let offset_mode = limit + .skip + .as_ref() + .map(|expr| producer.handle_expr(expr.as_ref(), &empty_schema)) + .transpose()? + .map(Box::new) + .map(fetch_rel::OffsetMode::OffsetExpr); + let count_mode = limit + .fetch + .as_ref() + .map(|expr| producer.handle_expr(expr.as_ref(), &empty_schema)) + .transpose()? + .map(Box::new) + .map(fetch_rel::CountMode::CountExpr); + Ok(Box::new(Rel { + rel_type: Some(RelType::Fetch(Box::new(FetchRel { + common: None, + input: Some(input), + offset_mode, + count_mode, + advanced_extension: None, + }))), + })) +} + +pub fn from_sort(producer: &mut impl SubstraitProducer, sort: &Sort) -> Result> { + let Sort { expr, input, fetch } = sort; + let sort_fields = expr + .iter() + .map(|e| substrait_sort_field(producer, e, input.schema())) + .collect::>>()?; + + let input = producer.handle_plan(input.as_ref())?; + + let sort_rel = Box::new(Rel { + rel_type: Some(RelType::Sort(Box::new(SortRel { + common: None, + input: Some(input), + sorts: sort_fields, + advanced_extension: None, + }))), + }); + + match fetch { + Some(amount) => { + let count_mode = + Some(fetch_rel::CountMode::CountExpr(Box::new(Expression { + rex_type: Some(RexType::Literal(Literal { + nullable: false, + type_variation_reference: DEFAULT_TYPE_VARIATION_REF, + literal_type: Some(LiteralType::I64(*amount as i64)), + })), + }))); Ok(Box::new(Rel { rel_type: Some(RelType::Fetch(Box::new(FetchRel { common: None, - input: Some(input), - offset_mode, + input: Some(sort_rel), + offset_mode: None, count_mode, advanced_extension: None, }))), })) } - LogicalPlan::Sort(datafusion::logical_expr::Sort { expr, input, fetch }) => { - let sort_fields = expr - .iter() - .map(|e| substrait_sort_field(state, e, input.schema(), extensions)) - .collect::>>()?; + None => Ok(sort_rel), + } +} - let input = to_substrait_rel(input.as_ref(), state, extensions)?; +pub fn from_aggregate( + producer: &mut impl SubstraitProducer, + agg: &Aggregate, +) -> Result> { + let input = producer.handle_plan(agg.input.as_ref())?; + let (grouping_expressions, groupings) = + to_substrait_groupings(producer, &agg.group_expr, agg.input.schema())?; + let measures = agg + .aggr_expr + .iter() + .map(|e| to_substrait_agg_measure(producer, e, agg.input.schema())) + .collect::>>()?; - let sort_rel = Box::new(Rel { - rel_type: Some(RelType::Sort(Box::new(SortRel { - common: None, - input: Some(input), - sorts: sort_fields, - advanced_extension: None, - }))), - }); - - match fetch { - Some(amount) => { - let count_mode = - Some(fetch_rel::CountMode::CountExpr(Box::new(Expression { - rex_type: Some(RexType::Literal(Literal { - nullable: false, - type_variation_reference: DEFAULT_TYPE_VARIATION_REF, - literal_type: Some(LiteralType::I64(*amount as i64)), - })), - }))); - Ok(Box::new(Rel { - rel_type: Some(RelType::Fetch(Box::new(FetchRel { - common: None, - input: Some(sort_rel), - offset_mode: None, - count_mode, - advanced_extension: None, - }))), - })) - } - None => Ok(sort_rel), - } - } - LogicalPlan::Aggregate(agg) => { - let input = to_substrait_rel(agg.input.as_ref(), state, extensions)?; - let (grouping_expressions, groupings) = to_substrait_groupings( - state, - &agg.group_expr, - agg.input.schema(), - extensions, - )?; - let measures = agg - .aggr_expr - .iter() - .map(|e| { - to_substrait_agg_measure(state, e, agg.input.schema(), extensions) - }) - .collect::>>()?; + Ok(Box::new(Rel { + rel_type: Some(RelType::Aggregate(Box::new(AggregateRel { + common: None, + input: Some(input), + grouping_expressions, + groupings, + measures, + advanced_extension: None, + }))), + })) +} - Ok(Box::new(Rel { - rel_type: Some(RelType::Aggregate(Box::new(AggregateRel { - common: None, - input: Some(input), - grouping_expressions, - groupings, - measures, - advanced_extension: None, - }))), - })) - } - LogicalPlan::Distinct(Distinct::All(plan)) => { +pub fn from_distinct( + producer: &mut impl SubstraitProducer, + distinct: &Distinct, +) -> Result> { + match distinct { + Distinct::All(plan) => { // Use Substrait's AggregateRel with empty measures to represent `select distinct` - let input = to_substrait_rel(plan.as_ref(), state, extensions)?; + let input = producer.handle_plan(plan.as_ref())?; // Get grouping keys from the input relation's number of output fields let grouping = (0..plan.schema().fields().len()) .map(substrait_field_ref) .collect::>>()?; + #[allow(deprecated)] Ok(Box::new(Rel { rel_type: Some(RelType::Aggregate(Box::new(AggregateRel { common: None, @@ -450,220 +812,176 @@ pub fn to_substrait_rel( }))), })) } - LogicalPlan::Join(join) => { - let left = to_substrait_rel(join.left.as_ref(), state, extensions)?; - let right = to_substrait_rel(join.right.as_ref(), state, extensions)?; - let join_type = to_substrait_jointype(join.join_type); - // we only support basic joins so return an error for anything not yet supported - match join.join_constraint { - JoinConstraint::On => {} - JoinConstraint::Using => { - return not_impl_err!("join constraint: `using`") - } - } - // parse filter if exists - let in_join_schema = join.left.schema().join(join.right.schema())?; - let join_filter = match &join.filter { - Some(filter) => Some(to_substrait_rex( - state, - filter, - &Arc::new(in_join_schema), - 0, - extensions, - )?), - None => None, - }; + Distinct::On(_) => not_impl_err!("Cannot convert Distinct::On"), + } +} - // map the left and right columns to binary expressions in the form `l = r` - // build a single expression for the ON condition, such as `l.a = r.a AND l.b = r.b` - let eq_op = if join.null_equals_null { - Operator::IsNotDistinctFrom - } else { - Operator::Eq - }; - let join_on = to_substrait_join_expr( - state, - &join.on, - eq_op, - join.left.schema(), - join.right.schema(), - extensions, - )?; - - // create conjunction between `join_on` and `join_filter` to embed all join conditions, - // whether equal or non-equal in a single expression - let join_expr = match &join_on { - Some(on_expr) => match &join_filter { - Some(filter) => Some(Box::new(make_binary_op_scalar_func( - on_expr, - filter, - Operator::And, - extensions, - ))), - None => join_on.map(Box::new), // the join expression will only contain `join_on` if filter doesn't exist - }, - None => match &join_filter { - Some(_) => join_filter.map(Box::new), // the join expression will only contain `join_filter` if the `on` condition doesn't exist - None => None, - }, - }; +pub fn from_join(producer: &mut impl SubstraitProducer, join: &Join) -> Result> { + let left = producer.handle_plan(join.left.as_ref())?; + let right = producer.handle_plan(join.right.as_ref())?; + let join_type = to_substrait_jointype(join.join_type); + // we only support basic joins so return an error for anything not yet supported + match join.join_constraint { + JoinConstraint::On => {} + JoinConstraint::Using => return not_impl_err!("join constraint: `using`"), + } + let in_join_schema = Arc::new(join.left.schema().join(join.right.schema())?); - Ok(Box::new(Rel { - rel_type: Some(RelType::Join(Box::new(JoinRel { - common: None, - left: Some(left), - right: Some(right), - r#type: join_type as i32, - expression: join_expr, - post_join_filter: None, - advanced_extension: None, - }))), - })) - } - LogicalPlan::SubqueryAlias(alias) => { - // Do nothing if encounters SubqueryAlias - // since there is no corresponding relation type in Substrait - to_substrait_rel(alias.input.as_ref(), state, extensions) - } - LogicalPlan::Union(union) => { - let input_rels = union - .inputs - .iter() - .map(|input| to_substrait_rel(input.as_ref(), state, extensions)) - .collect::>>()? - .into_iter() - .map(|ptr| *ptr) - .collect(); - Ok(Box::new(Rel { - rel_type: Some(RelType::Set(SetRel { - common: None, - inputs: input_rels, - op: set_rel::SetOp::UnionAll as i32, // UNION DISTINCT gets translated to AGGREGATION + UNION ALL - advanced_extension: None, - })), - })) - } - LogicalPlan::Window(window) => { - let input = to_substrait_rel(window.input.as_ref(), state, extensions)?; + // convert filter if present + let join_filter = match &join.filter { + Some(filter) => Some(producer.handle_expr(filter, &in_join_schema)?), + None => None, + }; - // create a field reference for each input field - let mut expressions = (0..window.input.schema().fields().len()) - .map(substrait_field_ref) - .collect::>>()?; + // map the left and right columns to binary expressions in the form `l = r` + // build a single expression for the ON condition, such as `l.a = r.a AND l.b = r.b` + let eq_op = if join.null_equals_null { + Operator::IsNotDistinctFrom + } else { + Operator::Eq + }; + let join_on = to_substrait_join_expr(producer, &join.on, eq_op, &in_join_schema)?; + + // create conjunction between `join_on` and `join_filter` to embed all join conditions, + // whether equal or non-equal in a single expression + let join_expr = match &join_on { + Some(on_expr) => match &join_filter { + Some(filter) => Some(Box::new(make_binary_op_scalar_func( + producer, + on_expr, + filter, + Operator::And, + ))), + None => join_on.map(Box::new), // the join expression will only contain `join_on` if filter doesn't exist + }, + None => match &join_filter { + Some(_) => join_filter.map(Box::new), // the join expression will only contain `join_filter` if the `on` condition doesn't exist + None => None, + }, + }; - // process and add each window function expression - for expr in &window.window_expr { - expressions.push(to_substrait_rex( - state, - expr, - window.input.schema(), - 0, - extensions, - )?); - } + Ok(Box::new(Rel { + rel_type: Some(RelType::Join(Box::new(JoinRel { + common: None, + left: Some(left), + right: Some(right), + r#type: join_type as i32, + expression: join_expr, + post_join_filter: None, + advanced_extension: None, + }))), + })) +} - let emit_kind = create_project_remapping( - expressions.len(), - window.input.schema().fields().len(), - ); - let common = RelCommon { - emit_kind: Some(emit_kind), - hint: None, - advanced_extension: None, - }; - let project_rel = Box::new(ProjectRel { - common: Some(common), - input: Some(input), - expressions, - advanced_extension: None, - }); +pub fn from_subquery_alias( + producer: &mut impl SubstraitProducer, + alias: &SubqueryAlias, +) -> Result> { + // Do nothing if encounters SubqueryAlias + // since there is no corresponding relation type in Substrait + producer.handle_plan(alias.input.as_ref()) +} - Ok(Box::new(Rel { - rel_type: Some(RelType::Project(project_rel)), - })) +pub fn from_union( + producer: &mut impl SubstraitProducer, + union: &Union, +) -> Result> { + let input_rels = union + .inputs + .iter() + .map(|input| producer.handle_plan(input.as_ref())) + .collect::>>()? + .into_iter() + .map(|ptr| *ptr) + .collect(); + Ok(Box::new(Rel { + rel_type: Some(RelType::Set(SetRel { + common: None, + inputs: input_rels, + op: set_rel::SetOp::UnionAll as i32, // UNION DISTINCT gets translated to AGGREGATION + UNION ALL + advanced_extension: None, + })), + })) +} + +pub fn from_window( + producer: &mut impl SubstraitProducer, + window: &Window, +) -> Result> { + let input = producer.handle_plan(window.input.as_ref())?; + + // create a field reference for each input field + let mut expressions = (0..window.input.schema().fields().len()) + .map(substrait_field_ref) + .collect::>>()?; + + // process and add each window function expression + for expr in &window.window_expr { + expressions.push(producer.handle_expr(expr, window.input.schema())?); + } + + let emit_kind = + create_project_remapping(expressions.len(), window.input.schema().fields().len()); + let common = RelCommon { + emit_kind: Some(emit_kind), + hint: None, + advanced_extension: None, + }; + let project_rel = Box::new(ProjectRel { + common: Some(common), + input: Some(input), + expressions, + advanced_extension: None, + }); + + Ok(Box::new(Rel { + rel_type: Some(RelType::Project(project_rel)), + })) +} + +pub fn from_repartition( + producer: &mut impl SubstraitProducer, + repartition: &Repartition, +) -> Result> { + let input = producer.handle_plan(repartition.input.as_ref())?; + let partition_count = match repartition.partitioning_scheme { + Partitioning::RoundRobinBatch(num) => num, + Partitioning::Hash(_, num) => num, + Partitioning::DistributeBy(_) => { + return not_impl_err!( + "Physical plan does not support DistributeBy partitioning" + ) } - LogicalPlan::Repartition(repartition) => { - let input = to_substrait_rel(repartition.input.as_ref(), state, extensions)?; - let partition_count = match repartition.partitioning_scheme { - Partitioning::RoundRobinBatch(num) => num, - Partitioning::Hash(_, num) => num, - Partitioning::DistributeBy(_) => { - return not_impl_err!( - "Physical plan does not support DistributeBy partitioning" - ) - } - }; - // ref: https://substrait.io/relations/physical_relations/#exchange-types - let exchange_kind = match &repartition.partitioning_scheme { - Partitioning::RoundRobinBatch(_) => { - ExchangeKind::RoundRobin(RoundRobin::default()) - } - Partitioning::Hash(exprs, _) => { - let fields = exprs - .iter() - .map(|e| { - try_to_substrait_field_reference( - e, - repartition.input.schema(), - ) - }) - .collect::>>()?; - ExchangeKind::ScatterByFields(ScatterFields { fields }) - } - Partitioning::DistributeBy(_) => { - return not_impl_err!( - "Physical plan does not support DistributeBy partitioning" - ) - } - }; - let exchange_rel = ExchangeRel { - common: None, - input: Some(input), - exchange_kind: Some(exchange_kind), - advanced_extension: None, - partition_count: partition_count as i32, - targets: vec![], - }; - Ok(Box::new(Rel { - rel_type: Some(RelType::Exchange(Box::new(exchange_rel))), - })) + }; + // ref: https://substrait.io/relations/physical_relations/#exchange-types + let exchange_kind = match &repartition.partitioning_scheme { + Partitioning::RoundRobinBatch(_) => { + ExchangeKind::RoundRobin(RoundRobin::default()) } - LogicalPlan::Extension(extension_plan) => { - let extension_bytes = state - .serializer_registry() - .serialize_logical_plan(extension_plan.node.as_ref())?; - let detail = ProtoAny { - type_url: extension_plan.node.name().to_string(), - value: extension_bytes.into(), - }; - let mut inputs_rel = extension_plan - .node - .inputs() - .into_iter() - .map(|plan| to_substrait_rel(plan, state, extensions)) + Partitioning::Hash(exprs, _) => { + let fields = exprs + .iter() + .map(|e| try_to_substrait_field_reference(e, repartition.input.schema())) .collect::>>()?; - let rel_type = match inputs_rel.len() { - 0 => RelType::ExtensionLeaf(ExtensionLeafRel { - common: None, - detail: Some(detail), - }), - 1 => RelType::ExtensionSingle(Box::new(ExtensionSingleRel { - common: None, - detail: Some(detail), - input: Some(inputs_rel.pop().unwrap()), - })), - _ => RelType::ExtensionMulti(ExtensionMultiRel { - common: None, - detail: Some(detail), - inputs: inputs_rel.into_iter().map(|r| *r).collect(), - }), - }; - Ok(Box::new(Rel { - rel_type: Some(rel_type), - })) + ExchangeKind::ScatterByFields(ScatterFields { fields }) } - _ => not_impl_err!("Unsupported operator: {plan}"), - } + Partitioning::DistributeBy(_) => { + return not_impl_err!( + "Physical plan does not support DistributeBy partitioning" + ) + } + }; + let exchange_rel = ExchangeRel { + common: None, + input: Some(input), + exchange_kind: Some(exchange_kind), + advanced_extension: None, + partition_count: partition_count as i32, + targets: vec![], + }; + Ok(Box::new(Rel { + rel_type: Some(RelType::Exchange(Box::new(exchange_rel))), + })) } /// By default, a Substrait Project outputs all input fields followed by all expressions. @@ -730,32 +1048,23 @@ fn to_substrait_named_struct(schema: &DFSchemaRef) -> Result { } fn to_substrait_join_expr( - state: &dyn SubstraitPlanningState, + producer: &mut impl SubstraitProducer, join_conditions: &Vec<(Expr, Expr)>, eq_op: Operator, - left_schema: &DFSchemaRef, - right_schema: &DFSchemaRef, - extensions: &mut Extensions, + join_schema: &DFSchemaRef, ) -> Result> { // Only support AND conjunction for each binary expression in join conditions let mut exprs: Vec = vec![]; for (left, right) in join_conditions { - // Parse left - let l = to_substrait_rex(state, left, left_schema, 0, extensions)?; - // Parse right - let r = to_substrait_rex( - state, - right, - right_schema, - left_schema.fields().len(), // offset to return the correct index - extensions, - )?; + let l = producer.handle_expr(left, join_schema)?; + let r = producer.handle_expr(right, join_schema)?; // AND with existing expression - exprs.push(make_binary_op_scalar_func(&l, &r, eq_op, extensions)); + exprs.push(make_binary_op_scalar_func(producer, &l, &r, eq_op)); } + let join_expr: Option = exprs.into_iter().reduce(|acc: Expression, e: Expression| { - make_binary_op_scalar_func(&acc, &e, Operator::And, extensions) + make_binary_op_scalar_func(producer, &acc, &e, Operator::And) }); Ok(join_expr) } @@ -811,23 +1120,22 @@ pub fn operator_to_name(op: Operator) -> &'static str { } } -#[allow(deprecated)] pub fn parse_flat_grouping_exprs( - state: &dyn SubstraitPlanningState, + producer: &mut impl SubstraitProducer, exprs: &[Expr], schema: &DFSchemaRef, - extensions: &mut Extensions, ref_group_exprs: &mut Vec, ) -> Result { let mut expression_references = vec![]; let mut grouping_expressions = vec![]; for e in exprs { - let rex = to_substrait_rex(state, e, schema, 0, extensions)?; + let rex = producer.handle_expr(e, schema)?; grouping_expressions.push(rex.clone()); ref_group_exprs.push(rex); expression_references.push((ref_group_exprs.len() - 1) as u32); } + #[allow(deprecated)] Ok(Grouping { grouping_expressions, expression_references, @@ -835,10 +1143,9 @@ pub fn parse_flat_grouping_exprs( } pub fn to_substrait_groupings( - state: &dyn SubstraitPlanningState, + producer: &mut impl SubstraitProducer, exprs: &[Expr], schema: &DFSchemaRef, - extensions: &mut Extensions, ) -> Result<(Vec, Vec)> { let mut ref_group_exprs = vec![]; let groupings = match exprs.len() { @@ -851,10 +1158,9 @@ pub fn to_substrait_groupings( .iter() .map(|set| { parse_flat_grouping_exprs( - state, + producer, set, schema, - extensions, &mut ref_group_exprs, ) }) @@ -869,10 +1175,9 @@ pub fn to_substrait_groupings( .rev() .map(|set| { parse_flat_grouping_exprs( - state, + producer, set, schema, - extensions, &mut ref_group_exprs, ) }) @@ -880,66 +1185,81 @@ pub fn to_substrait_groupings( } }, _ => Ok(vec![parse_flat_grouping_exprs( - state, + producer, exprs, schema, - extensions, &mut ref_group_exprs, )?]), }, _ => Ok(vec![parse_flat_grouping_exprs( - state, + producer, exprs, schema, - extensions, &mut ref_group_exprs, )?]), }?; Ok((ref_group_exprs, groupings)) } -#[allow(deprecated)] +pub fn from_aggregate_function( + producer: &mut impl SubstraitProducer, + agg_fn: &expr::AggregateFunction, + schema: &DFSchemaRef, +) -> Result { + let expr::AggregateFunction { + func, + args, + distinct, + filter, + order_by, + null_treatment: _null_treatment, + } = agg_fn; + let sorts = if let Some(order_by) = order_by { + order_by + .iter() + .map(|expr| to_substrait_sort_field(producer, expr, schema)) + .collect::>>()? + } else { + vec![] + }; + let mut arguments: Vec = vec![]; + for arg in args { + arguments.push(FunctionArgument { + arg_type: Some(ArgType::Value(producer.handle_expr(arg, schema)?)), + }); + } + let function_anchor = producer.register_function(func.name().to_string()); + #[allow(deprecated)] + Ok(Measure { + measure: Some(AggregateFunction { + function_reference: function_anchor, + arguments, + sorts, + output_type: None, + invocation: match distinct { + true => AggregationInvocation::Distinct as i32, + false => AggregationInvocation::All as i32, + }, + phase: AggregationPhase::Unspecified as i32, + args: vec![], + options: vec![], + }), + filter: match filter { + Some(f) => Some(producer.handle_expr(f, schema)?), + None => None, + }, + }) +} + pub fn to_substrait_agg_measure( - state: &dyn SubstraitPlanningState, + producer: &mut impl SubstraitProducer, expr: &Expr, schema: &DFSchemaRef, - extensions: &mut Extensions, ) -> Result { match expr { - Expr::AggregateFunction(expr::AggregateFunction { func, args, distinct, filter, order_by, null_treatment: _, }) => { - let sorts = if let Some(order_by) = order_by { - order_by.iter().map(|expr| to_substrait_sort_field(state, expr, schema, extensions)).collect::>>()? - } else { - vec![] - }; - let mut arguments: Vec = vec![]; - for arg in args { - arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex(state, arg, schema, 0, extensions)?)) }); - } - let function_anchor = extensions.register_function(func.name().to_string()); - Ok(Measure { - measure: Some(AggregateFunction { - function_reference: function_anchor, - arguments, - sorts, - output_type: None, - invocation: match distinct { - true => AggregationInvocation::Distinct as i32, - false => AggregationInvocation::All as i32, - }, - phase: AggregationPhase::Unspecified as i32, - args: vec![], - options: vec![], - }), - filter: match filter { - Some(f) => Some(to_substrait_rex(state, f, schema, 0, extensions)?), - None => None - } - }) - - } - Expr::Alias(Alias{expr,..})=> { - to_substrait_agg_measure(state, expr, schema, extensions) + Expr::AggregateFunction(agg_fn) => from_aggregate_function(producer, agg_fn, schema), + Expr::Alias(Alias { expr, .. }) => { + to_substrait_agg_measure(producer, expr, schema) } _ => internal_err!( "Expression must be compatible with aggregation. Unsupported expression: {:?}. ExpressionType: {:?}", @@ -951,10 +1271,9 @@ pub fn to_substrait_agg_measure( /// Converts sort expression to corresponding substrait `SortField` fn to_substrait_sort_field( - state: &dyn SubstraitPlanningState, - sort: &Sort, + producer: &mut impl SubstraitProducer, + sort: &expr::Sort, schema: &DFSchemaRef, - extensions: &mut Extensions, ) -> Result { let sort_kind = match (sort.asc, sort.nulls_first) { (true, true) => SortDirection::AscNullsFirst, @@ -963,20 +1282,20 @@ fn to_substrait_sort_field( (false, false) => SortDirection::DescNullsLast, }; Ok(SortField { - expr: Some(to_substrait_rex(state, &sort.expr, schema, 0, extensions)?), + expr: Some(producer.handle_expr(&sort.expr, schema)?), sort_kind: Some(SortKind::Direction(sort_kind.into())), }) } /// Return Substrait scalar function with two arguments -#[allow(deprecated)] pub fn make_binary_op_scalar_func( + producer: &mut impl SubstraitProducer, lhs: &Expression, rhs: &Expression, op: Operator, - extensions: &mut Extensions, ) -> Expression { - let function_anchor = extensions.register_function(operator_to_name(op).to_string()); + let function_anchor = producer.register_function(operator_to_name(op).to_string()); + #[allow(deprecated)] Expression { rex_type: Some(RexType::ScalarFunction(ScalarFunction { function_reference: function_anchor, @@ -998,450 +1317,431 @@ pub fn make_binary_op_scalar_func( /// Convert DataFusion Expr to Substrait Rex /// /// # Arguments -/// -/// * `expr` - DataFusion expression to be parse into a Substrait expression -/// * `schema` - DataFusion input schema for looking up field qualifiers -/// * `col_ref_offset` - Offset for calculating Substrait field reference indices. -/// This should only be set by caller with more than one input relations i.e. Join. -/// Substrait expects one set of indices when joining two relations. -/// Let's say `left` and `right` have `m` and `n` columns, respectively. The `right` -/// relation will have column indices from `0` to `n-1`, however, Substrait will expect -/// the `right` indices to be offset by the `left`. This means Substrait will expect to -/// evaluate the join condition expression on indices [0 .. n-1, n .. n+m-1]. For example: -/// ```SELECT * -/// FROM t1 -/// JOIN t2 -/// ON t1.c1 = t2.c0;``` -/// where t1 consists of columns [c0, c1, c2], and t2 = columns [c0, c1] -/// the join condition should become -/// `col_ref(1) = col_ref(3 + 0)` -/// , where `3` is the number of `left` columns (`col_ref_offset`) and `0` is the index -/// of the join key column from `right` -/// * `extensions` - Substrait extension info. Contains registered function information -#[allow(deprecated)] +/// * `producer` - SubstraitProducer implementation which the handles the actual conversion +/// * `expr` - DataFusion expression to convert into a Substrait expression +/// * `schema` - DataFusion input schema for looking up columns pub fn to_substrait_rex( - state: &dyn SubstraitPlanningState, + producer: &mut impl SubstraitProducer, expr: &Expr, schema: &DFSchemaRef, - col_ref_offset: usize, - extensions: &mut Extensions, ) -> Result { match expr { - Expr::InList(InList { - expr, - list, - negated, - }) => { - let substrait_list = list - .iter() - .map(|x| to_substrait_rex(state, x, schema, col_ref_offset, extensions)) - .collect::>>()?; - let substrait_expr = - to_substrait_rex(state, expr, schema, col_ref_offset, extensions)?; - - let substrait_or_list = Expression { - rex_type: Some(RexType::SingularOrList(Box::new(SingularOrList { - value: Some(Box::new(substrait_expr)), - options: substrait_list, - }))), - }; - - if *negated { - let function_anchor = extensions.register_function("not".to_string()); - - Ok(Expression { - rex_type: Some(RexType::ScalarFunction(ScalarFunction { - function_reference: function_anchor, - arguments: vec![FunctionArgument { - arg_type: Some(ArgType::Value(substrait_or_list)), - }], - output_type: None, - args: vec![], - options: vec![], - })), - }) - } else { - Ok(substrait_or_list) - } + Expr::Alias(expr) => producer.handle_alias(expr, schema), + Expr::Column(expr) => producer.handle_column(expr, schema), + Expr::ScalarVariable(_, _) => { + not_impl_err!("Cannot convert {expr:?} to Substrait") } - Expr::ScalarFunction(fun) => { - let mut arguments: Vec = vec![]; - for arg in &fun.args { - arguments.push(FunctionArgument { - arg_type: Some(ArgType::Value(to_substrait_rex( - state, - arg, - schema, - col_ref_offset, - extensions, - )?)), - }); - } - - let function_anchor = extensions.register_function(fun.name().to_string()); - Ok(Expression { - rex_type: Some(RexType::ScalarFunction(ScalarFunction { - function_reference: function_anchor, - arguments, - output_type: None, - args: vec![], - options: vec![], - })), - }) + Expr::Literal(expr) => producer.handle_literal(expr), + Expr::BinaryExpr(expr) => producer.handle_binary_expr(expr, schema), + Expr::Like(expr) => producer.handle_like(expr, schema), + Expr::SimilarTo(_) => not_impl_err!("Cannot convert {expr:?} to Substrait"), + Expr::Not(_) => producer.handle_unary_expr(expr, schema), + Expr::IsNotNull(_) => producer.handle_unary_expr(expr, schema), + Expr::IsNull(_) => producer.handle_unary_expr(expr, schema), + Expr::IsTrue(_) => producer.handle_unary_expr(expr, schema), + Expr::IsFalse(_) => producer.handle_unary_expr(expr, schema), + Expr::IsUnknown(_) => producer.handle_unary_expr(expr, schema), + Expr::IsNotTrue(_) => producer.handle_unary_expr(expr, schema), + Expr::IsNotFalse(_) => producer.handle_unary_expr(expr, schema), + Expr::IsNotUnknown(_) => producer.handle_unary_expr(expr, schema), + Expr::Negative(_) => producer.handle_unary_expr(expr, schema), + Expr::Between(expr) => producer.handle_between(expr, schema), + Expr::Case(expr) => producer.handle_case(expr, schema), + Expr::Cast(expr) => producer.handle_cast(expr, schema), + Expr::TryCast(expr) => producer.handle_try_cast(expr, schema), + Expr::ScalarFunction(expr) => producer.handle_scalar_function(expr, schema), + Expr::AggregateFunction(_) => { + internal_err!( + "AggregateFunction should only be encountered as part of a LogicalPlan::Aggregate" + ) } - Expr::Between(Between { - expr, - negated, - low, - high, - }) => { - if *negated { - // `expr NOT BETWEEN low AND high` can be translated into (expr < low OR high < expr) - let substrait_expr = - to_substrait_rex(state, expr, schema, col_ref_offset, extensions)?; - let substrait_low = - to_substrait_rex(state, low, schema, col_ref_offset, extensions)?; - let substrait_high = - to_substrait_rex(state, high, schema, col_ref_offset, extensions)?; - - let l_expr = make_binary_op_scalar_func( - &substrait_expr, - &substrait_low, - Operator::Lt, - extensions, - ); - let r_expr = make_binary_op_scalar_func( - &substrait_high, - &substrait_expr, - Operator::Lt, - extensions, - ); - - Ok(make_binary_op_scalar_func( - &l_expr, - &r_expr, - Operator::Or, - extensions, - )) - } else { - // `expr BETWEEN low AND high` can be translated into (low <= expr AND expr <= high) - let substrait_expr = - to_substrait_rex(state, expr, schema, col_ref_offset, extensions)?; - let substrait_low = - to_substrait_rex(state, low, schema, col_ref_offset, extensions)?; - let substrait_high = - to_substrait_rex(state, high, schema, col_ref_offset, extensions)?; - - let l_expr = make_binary_op_scalar_func( - &substrait_low, - &substrait_expr, - Operator::LtEq, - extensions, - ); - let r_expr = make_binary_op_scalar_func( - &substrait_expr, - &substrait_high, - Operator::LtEq, - extensions, - ); - - Ok(make_binary_op_scalar_func( - &l_expr, - &r_expr, - Operator::And, - extensions, - )) - } + Expr::WindowFunction(expr) => producer.handle_window_function(expr, schema), + Expr::InList(expr) => producer.handle_in_list(expr, schema), + Expr::Exists(expr) => not_impl_err!("Cannot convert {expr:?} to Substrait"), + Expr::InSubquery(expr) => producer.handle_in_subquery(expr, schema), + Expr::ScalarSubquery(expr) => { + not_impl_err!("Cannot convert {expr:?} to Substrait") } - Expr::Column(col) => { - let index = schema.index_of_column(col)?; - substrait_field_ref(index + col_ref_offset) + Expr::Wildcard { .. } => not_impl_err!("Cannot convert {expr:?} to Substrait"), + Expr::GroupingSet(expr) => not_impl_err!("Cannot convert {expr:?} to Substrait"), + Expr::Placeholder(expr) => not_impl_err!("Cannot convert {expr:?} to Substrait"), + Expr::OuterReferenceColumn(_, _) => { + not_impl_err!("Cannot convert {expr:?} to Substrait") } - Expr::BinaryExpr(BinaryExpr { left, op, right }) => { - let l = to_substrait_rex(state, left, schema, col_ref_offset, extensions)?; - let r = to_substrait_rex(state, right, schema, col_ref_offset, extensions)?; + Expr::Unnest(expr) => not_impl_err!("Cannot convert {expr:?} to Substrait"), + } +} - Ok(make_binary_op_scalar_func(&l, &r, *op, extensions)) - } - Expr::Case(Case { - expr, - when_then_expr, - else_expr, - }) => { - let mut ifs: Vec = vec![]; - // Parse base - if let Some(e) = expr { - // Base expression exists - ifs.push(IfClause { - r#if: Some(to_substrait_rex( - state, - e, - schema, - col_ref_offset, - extensions, - )?), - then: None, - }); - } - // Parse `when`s - for (r#if, then) in when_then_expr { - ifs.push(IfClause { - r#if: Some(to_substrait_rex( - state, - r#if, - schema, - col_ref_offset, - extensions, - )?), - then: Some(to_substrait_rex( - state, - then, - schema, - col_ref_offset, - extensions, - )?), - }); - } +pub fn from_in_list( + producer: &mut impl SubstraitProducer, + in_list: &InList, + schema: &DFSchemaRef, +) -> Result { + let InList { + expr, + list, + negated, + } = in_list; + let substrait_list = list + .iter() + .map(|x| producer.handle_expr(x, schema)) + .collect::>>()?; + let substrait_expr = producer.handle_expr(expr, schema)?; + + let substrait_or_list = Expression { + rex_type: Some(RexType::SingularOrList(Box::new(SingularOrList { + value: Some(Box::new(substrait_expr)), + options: substrait_list, + }))), + }; - // Parse outer `else` - let r#else: Option> = match else_expr { - Some(e) => Some(Box::new(to_substrait_rex( - state, - e, - schema, - col_ref_offset, - extensions, - )?)), - None => None, - }; + if *negated { + let function_anchor = producer.register_function("not".to_string()); - Ok(Expression { - rex_type: Some(RexType::IfThen(Box::new(IfThen { ifs, r#else }))), - }) - } - Expr::Cast(Cast { expr, data_type }) => Ok(Expression { - rex_type: Some(RexType::Cast(Box::new( - substrait::proto::expression::Cast { - r#type: Some(to_substrait_type(data_type, true)?), - input: Some(Box::new(to_substrait_rex( - state, - expr, - schema, - col_ref_offset, - extensions, - )?)), - failure_behavior: FailureBehavior::ThrowException.into(), - }, - ))), - }), - Expr::TryCast(TryCast { expr, data_type }) => Ok(Expression { - rex_type: Some(RexType::Cast(Box::new( - substrait::proto::expression::Cast { - r#type: Some(to_substrait_type(data_type, true)?), - input: Some(Box::new(to_substrait_rex( - state, - expr, - schema, - col_ref_offset, - extensions, - )?)), - failure_behavior: FailureBehavior::ReturnNull.into(), - }, - ))), - }), - Expr::Literal(value) => to_substrait_literal_expr(value, extensions), - Expr::Alias(Alias { expr, .. }) => { - to_substrait_rex(state, expr, schema, col_ref_offset, extensions) - } - Expr::WindowFunction(WindowFunction { - fun, - args, - partition_by, - order_by, - window_frame, - null_treatment: _, - }) => { - // function reference - let function_anchor = extensions.register_function(fun.to_string()); - // arguments - let mut arguments: Vec = vec![]; - for arg in args { - arguments.push(FunctionArgument { - arg_type: Some(ArgType::Value(to_substrait_rex( - state, - arg, - schema, - col_ref_offset, - extensions, - )?)), - }); - } - // partition by expressions - let partition_by = partition_by - .iter() - .map(|e| to_substrait_rex(state, e, schema, col_ref_offset, extensions)) - .collect::>>()?; - // order by expressions - let order_by = order_by - .iter() - .map(|e| substrait_sort_field(state, e, schema, extensions)) - .collect::>>()?; - // window frame - let bounds = to_substrait_bounds(window_frame)?; - let bound_type = to_substrait_bound_type(window_frame)?; - Ok(make_substrait_window_function( - function_anchor, - arguments, - partition_by, - order_by, - bounds, - bound_type, - )) - } - Expr::Like(Like { - negated, - expr, - pattern, - escape_char, - case_insensitive, - }) => make_substrait_like_expr( - state, - *case_insensitive, - *negated, - expr, - pattern, - *escape_char, - schema, - col_ref_offset, - extensions, - ), - Expr::InSubquery(InSubquery { - expr, - subquery, - negated, - }) => { - let substrait_expr = - to_substrait_rex(state, expr, schema, col_ref_offset, extensions)?; - - let subquery_plan = - to_substrait_rel(subquery.subquery.as_ref(), state, extensions)?; - - let substrait_subquery = Expression { - rex_type: Some(RexType::Subquery(Box::new(Subquery { - subquery_type: Some( - substrait::proto::expression::subquery::SubqueryType::InPredicate( - Box::new(InPredicate { - needles: (vec![substrait_expr]), - haystack: Some(subquery_plan), - }), - ), + #[allow(deprecated)] + Ok(Expression { + rex_type: Some(RexType::ScalarFunction(ScalarFunction { + function_reference: function_anchor, + arguments: vec![FunctionArgument { + arg_type: Some(ArgType::Value(substrait_or_list)), + }], + output_type: None, + args: vec![], + options: vec![], + })), + }) + } else { + Ok(substrait_or_list) + } +} + +pub fn from_scalar_function( + producer: &mut impl SubstraitProducer, + fun: &expr::ScalarFunction, + schema: &DFSchemaRef, +) -> Result { + let mut arguments: Vec = vec![]; + for arg in &fun.args { + arguments.push(FunctionArgument { + arg_type: Some(ArgType::Value(producer.handle_expr(arg, schema)?)), + }); + } + + let function_anchor = producer.register_function(fun.name().to_string()); + #[allow(deprecated)] + Ok(Expression { + rex_type: Some(RexType::ScalarFunction(ScalarFunction { + function_reference: function_anchor, + arguments, + output_type: None, + options: vec![], + args: vec![], + })), + }) +} + +pub fn from_between( + producer: &mut impl SubstraitProducer, + between: &Between, + schema: &DFSchemaRef, +) -> Result { + let Between { + expr, + negated, + low, + high, + } = between; + if *negated { + // `expr NOT BETWEEN low AND high` can be translated into (expr < low OR high < expr) + let substrait_expr = producer.handle_expr(expr.as_ref(), schema)?; + let substrait_low = producer.handle_expr(low.as_ref(), schema)?; + let substrait_high = producer.handle_expr(high.as_ref(), schema)?; + + let l_expr = make_binary_op_scalar_func( + producer, + &substrait_expr, + &substrait_low, + Operator::Lt, + ); + let r_expr = make_binary_op_scalar_func( + producer, + &substrait_high, + &substrait_expr, + Operator::Lt, + ); + + Ok(make_binary_op_scalar_func( + producer, + &l_expr, + &r_expr, + Operator::Or, + )) + } else { + // `expr BETWEEN low AND high` can be translated into (low <= expr AND expr <= high) + let substrait_expr = producer.handle_expr(expr.as_ref(), schema)?; + let substrait_low = producer.handle_expr(low.as_ref(), schema)?; + let substrait_high = producer.handle_expr(high.as_ref(), schema)?; + + let l_expr = make_binary_op_scalar_func( + producer, + &substrait_low, + &substrait_expr, + Operator::LtEq, + ); + let r_expr = make_binary_op_scalar_func( + producer, + &substrait_expr, + &substrait_high, + Operator::LtEq, + ); + + Ok(make_binary_op_scalar_func( + producer, + &l_expr, + &r_expr, + Operator::And, + )) + } +} +pub fn from_column(col: &Column, schema: &DFSchemaRef) -> Result { + let index = schema.index_of_column(col)?; + substrait_field_ref(index) +} + +pub fn from_binary_expr( + producer: &mut impl SubstraitProducer, + expr: &BinaryExpr, + schema: &DFSchemaRef, +) -> Result { + let BinaryExpr { left, op, right } = expr; + let l = producer.handle_expr(left, schema)?; + let r = producer.handle_expr(right, schema)?; + Ok(make_binary_op_scalar_func(producer, &l, &r, *op)) +} +pub fn from_case( + producer: &mut impl SubstraitProducer, + case: &Case, + schema: &DFSchemaRef, +) -> Result { + let Case { + expr, + when_then_expr, + else_expr, + } = case; + let mut ifs: Vec = vec![]; + // Parse base + if let Some(e) = expr { + // Base expression exists + ifs.push(IfClause { + r#if: Some(producer.handle_expr(e, schema)?), + then: None, + }); + } + // Parse `when`s + for (r#if, then) in when_then_expr { + ifs.push(IfClause { + r#if: Some(producer.handle_expr(r#if, schema)?), + then: Some(producer.handle_expr(then, schema)?), + }); + } + + // Parse outer `else` + let r#else: Option> = match else_expr { + Some(e) => Some(Box::new(producer.handle_expr(e, schema)?)), + None => None, + }; + + Ok(Expression { + rex_type: Some(RexType::IfThen(Box::new(IfThen { ifs, r#else }))), + }) +} + +pub fn from_cast( + producer: &mut impl SubstraitProducer, + cast: &Cast, + schema: &DFSchemaRef, +) -> Result { + let Cast { expr, data_type } = cast; + Ok(Expression { + rex_type: Some(RexType::Cast(Box::new( + substrait::proto::expression::Cast { + r#type: Some(to_substrait_type(data_type, true)?), + input: Some(Box::new(producer.handle_expr(expr, schema)?)), + failure_behavior: FailureBehavior::ThrowException.into(), + }, + ))), + }) +} + +pub fn from_try_cast( + producer: &mut impl SubstraitProducer, + cast: &TryCast, + schema: &DFSchemaRef, +) -> Result { + let TryCast { expr, data_type } = cast; + Ok(Expression { + rex_type: Some(RexType::Cast(Box::new( + substrait::proto::expression::Cast { + r#type: Some(to_substrait_type(data_type, true)?), + input: Some(Box::new(producer.handle_expr(expr, schema)?)), + failure_behavior: FailureBehavior::ReturnNull.into(), + }, + ))), + }) +} + +pub fn from_literal( + producer: &mut impl SubstraitProducer, + value: &ScalarValue, +) -> Result { + to_substrait_literal_expr(producer, value) +} + +pub fn from_alias( + producer: &mut impl SubstraitProducer, + alias: &Alias, + schema: &DFSchemaRef, +) -> Result { + producer.handle_expr(alias.expr.as_ref(), schema) +} + +pub fn from_window_function( + producer: &mut impl SubstraitProducer, + window_fn: &WindowFunction, + schema: &DFSchemaRef, +) -> Result { + let WindowFunction { + fun, + args, + partition_by, + order_by, + window_frame, + null_treatment: _, + } = window_fn; + // function reference + let function_anchor = producer.register_function(fun.to_string()); + // arguments + let mut arguments: Vec = vec![]; + for arg in args { + arguments.push(FunctionArgument { + arg_type: Some(ArgType::Value(producer.handle_expr(arg, schema)?)), + }); + } + // partition by expressions + let partition_by = partition_by + .iter() + .map(|e| producer.handle_expr(e, schema)) + .collect::>>()?; + // order by expressions + let order_by = order_by + .iter() + .map(|e| substrait_sort_field(producer, e, schema)) + .collect::>>()?; + // window frame + let bounds = to_substrait_bounds(window_frame)?; + let bound_type = to_substrait_bound_type(window_frame)?; + Ok(make_substrait_window_function( + function_anchor, + arguments, + partition_by, + order_by, + bounds, + bound_type, + )) +} + +pub fn from_like( + producer: &mut impl SubstraitProducer, + like: &Like, + schema: &DFSchemaRef, +) -> Result { + let Like { + negated, + expr, + pattern, + escape_char, + case_insensitive, + } = like; + make_substrait_like_expr( + producer, + *case_insensitive, + *negated, + expr, + pattern, + *escape_char, + schema, + ) +} + +pub fn from_in_subquery( + producer: &mut impl SubstraitProducer, + subquery: &InSubquery, + schema: &DFSchemaRef, +) -> Result { + let InSubquery { + expr, + subquery, + negated, + } = subquery; + let substrait_expr = producer.handle_expr(expr, schema)?; + + let subquery_plan = producer.handle_plan(subquery.subquery.as_ref())?; + + let substrait_subquery = Expression { + rex_type: Some(RexType::Subquery(Box::new( + substrait::proto::expression::Subquery { + subquery_type: Some( + substrait::proto::expression::subquery::SubqueryType::InPredicate( + Box::new(InPredicate { + needles: (vec![substrait_expr]), + haystack: Some(subquery_plan), + }), ), - }))), - }; - if *negated { - let function_anchor = extensions.register_function("not".to_string()); - - Ok(Expression { - rex_type: Some(RexType::ScalarFunction(ScalarFunction { - function_reference: function_anchor, - arguments: vec![FunctionArgument { - arg_type: Some(ArgType::Value(substrait_subquery)), - }], - output_type: None, - args: vec![], - options: vec![], - })), - }) - } else { - Ok(substrait_subquery) - } - } - Expr::Not(arg) => to_substrait_unary_scalar_fn( - state, - "not", - arg, - schema, - col_ref_offset, - extensions, - ), - Expr::IsNull(arg) => to_substrait_unary_scalar_fn( - state, - "is_null", - arg, - schema, - col_ref_offset, - extensions, - ), - Expr::IsNotNull(arg) => to_substrait_unary_scalar_fn( - state, - "is_not_null", - arg, - schema, - col_ref_offset, - extensions, - ), - Expr::IsTrue(arg) => to_substrait_unary_scalar_fn( - state, - "is_true", - arg, - schema, - col_ref_offset, - extensions, - ), - Expr::IsFalse(arg) => to_substrait_unary_scalar_fn( - state, - "is_false", - arg, - schema, - col_ref_offset, - extensions, - ), - Expr::IsUnknown(arg) => to_substrait_unary_scalar_fn( - state, - "is_unknown", - arg, - schema, - col_ref_offset, - extensions, - ), - Expr::IsNotTrue(arg) => to_substrait_unary_scalar_fn( - state, - "is_not_true", - arg, - schema, - col_ref_offset, - extensions, - ), - Expr::IsNotFalse(arg) => to_substrait_unary_scalar_fn( - state, - "is_not_false", - arg, - schema, - col_ref_offset, - extensions, - ), - Expr::IsNotUnknown(arg) => to_substrait_unary_scalar_fn( - state, - "is_not_unknown", - arg, - schema, - col_ref_offset, - extensions, - ), - Expr::Negative(arg) => to_substrait_unary_scalar_fn( - state, - "negate", - arg, - schema, - col_ref_offset, - extensions, - ), - _ => { - not_impl_err!("Unsupported expression: {expr:?}") - } + ), + }, + ))), + }; + if *negated { + let function_anchor = producer.register_function("not".to_string()); + + #[allow(deprecated)] + Ok(Expression { + rex_type: Some(RexType::ScalarFunction(ScalarFunction { + function_reference: function_anchor, + arguments: vec![FunctionArgument { + arg_type: Some(ArgType::Value(substrait_subquery)), + }], + output_type: None, + args: vec![], + options: vec![], + })), + }) + } else { + Ok(substrait_subquery) } } +pub fn from_unary_expr( + producer: &mut impl SubstraitProducer, + expr: &Expr, + schema: &DFSchemaRef, +) -> Result { + let (fn_name, arg) = match expr { + Expr::Not(arg) => ("not", arg), + Expr::IsNull(arg) => ("is_null", arg), + Expr::IsNotNull(arg) => ("is_not_null", arg), + Expr::IsTrue(arg) => ("is_true", arg), + Expr::IsFalse(arg) => ("is_false", arg), + Expr::IsUnknown(arg) => ("is_unknown", arg), + Expr::IsNotTrue(arg) => ("is_not_true", arg), + Expr::IsNotFalse(arg) => ("is_not_false", arg), + Expr::IsNotUnknown(arg) => ("is_not_unknown", arg), + Expr::Negative(arg) => ("negate", arg), + expr => not_impl_err!("Unsupported expression: {expr:?}")?, + }; + to_substrait_unary_scalar_fn(producer, fn_name, arg, schema) +} + fn to_substrait_type(dt: &DataType, nullable: bool) -> Result { let nullability = if nullable { r#type::Nullability::Nullable as i32 @@ -1700,7 +2000,6 @@ fn to_substrait_type(dt: &DataType, nullable: bool) -> Result, @@ -1709,6 +2008,7 @@ fn make_substrait_window_function( bounds: (Bound, Bound), bounds_type: BoundsType, ) -> Expression { + #[allow(deprecated)] Expression { rex_type: Some(RexType::WindowFunction(SubstraitWindowFunction { function_reference, @@ -1727,29 +2027,25 @@ fn make_substrait_window_function( } } -#[allow(deprecated)] -#[allow(clippy::too_many_arguments)] fn make_substrait_like_expr( - state: &dyn SubstraitPlanningState, + producer: &mut impl SubstraitProducer, ignore_case: bool, negated: bool, expr: &Expr, pattern: &Expr, escape_char: Option, schema: &DFSchemaRef, - col_ref_offset: usize, - extensions: &mut Extensions, ) -> Result { let function_anchor = if ignore_case { - extensions.register_function("ilike".to_string()) + producer.register_function("ilike".to_string()) } else { - extensions.register_function("like".to_string()) + producer.register_function("like".to_string()) }; - let expr = to_substrait_rex(state, expr, schema, col_ref_offset, extensions)?; - let pattern = to_substrait_rex(state, pattern, schema, col_ref_offset, extensions)?; + let expr = producer.handle_expr(expr, schema)?; + let pattern = producer.handle_expr(pattern, schema)?; let escape_char = to_substrait_literal_expr( + producer, &ScalarValue::Utf8(escape_char.map(|c| c.to_string())), - extensions, )?; let arguments = vec![ FunctionArgument { @@ -1763,6 +2059,7 @@ fn make_substrait_like_expr( }, ]; + #[allow(deprecated)] let substrait_like = Expression { rex_type: Some(RexType::ScalarFunction(ScalarFunction { function_reference: function_anchor, @@ -1774,8 +2071,9 @@ fn make_substrait_like_expr( }; if negated { - let function_anchor = extensions.register_function("not".to_string()); + let function_anchor = producer.register_function("not".to_string()); + #[allow(deprecated)] Ok(Expression { rex_type: Some(RexType::ScalarFunction(ScalarFunction { function_reference: function_anchor, @@ -1847,8 +2145,8 @@ fn to_substrait_bounds(window_frame: &WindowFrame) -> Result<(Bound, Bound)> { } fn to_substrait_literal( + producer: &mut impl SubstraitProducer, value: &ScalarValue, - extensions: &mut Extensions, ) -> Result { if value.is_null() { return Ok(Literal { @@ -2026,11 +2324,11 @@ fn to_substrait_literal( DECIMAL_128_TYPE_VARIATION_REF, ), ScalarValue::List(l) => ( - convert_array_to_literal_list(l, extensions)?, + convert_array_to_literal_list(producer, l)?, DEFAULT_CONTAINER_TYPE_VARIATION_REF, ), ScalarValue::LargeList(l) => ( - convert_array_to_literal_list(l, extensions)?, + convert_array_to_literal_list(producer, l)?, LARGE_CONTAINER_TYPE_VARIATION_REF, ), ScalarValue::Map(m) => { @@ -2047,16 +2345,16 @@ fn to_substrait_literal( let keys = (0..m.keys().len()) .map(|i| { to_substrait_literal( + producer, &ScalarValue::try_from_array(&m.keys(), i)?, - extensions, ) }) .collect::>>()?; let values = (0..m.values().len()) .map(|i| { to_substrait_literal( + producer, &ScalarValue::try_from_array(&m.values(), i)?, - extensions, ) }) .collect::>>()?; @@ -2082,8 +2380,8 @@ fn to_substrait_literal( .iter() .map(|col| { to_substrait_literal( + producer, &ScalarValue::try_from_array(col, 0)?, - extensions, ) }) .collect::>>()?, @@ -2104,8 +2402,8 @@ fn to_substrait_literal( } fn convert_array_to_literal_list( + producer: &mut impl SubstraitProducer, array: &GenericListArray, - extensions: &mut Extensions, ) -> Result { assert_eq!(array.len(), 1); let nested_array = array.value(0); @@ -2113,8 +2411,8 @@ fn convert_array_to_literal_list( let values = (0..nested_array.len()) .map(|i| { to_substrait_literal( + producer, &ScalarValue::try_from_array(&nested_array, i)?, - extensions, ) }) .collect::>>()?; @@ -2133,10 +2431,10 @@ fn convert_array_to_literal_list( } fn to_substrait_literal_expr( + producer: &mut impl SubstraitProducer, value: &ScalarValue, - extensions: &mut Extensions, ) -> Result { - let literal = to_substrait_literal(value, extensions)?; + let literal = to_substrait_literal(producer, value)?; Ok(Expression { rex_type: Some(RexType::Literal(literal)), }) @@ -2144,16 +2442,13 @@ fn to_substrait_literal_expr( /// Util to generate substrait [RexType::ScalarFunction] with one argument fn to_substrait_unary_scalar_fn( - state: &dyn SubstraitPlanningState, + producer: &mut impl SubstraitProducer, fn_name: &str, arg: &Expr, schema: &DFSchemaRef, - col_ref_offset: usize, - extensions: &mut Extensions, ) -> Result { - let function_anchor = extensions.register_function(fn_name.to_string()); - let substrait_expr = - to_substrait_rex(state, arg, schema, col_ref_offset, extensions)?; + let function_anchor = producer.register_function(fn_name.to_string()); + let substrait_expr = producer.handle_expr(arg, schema)?; Ok(Expression { rex_type: Some(RexType::ScalarFunction(ScalarFunction { @@ -2194,17 +2489,16 @@ fn try_to_substrait_field_reference( } fn substrait_sort_field( - state: &dyn SubstraitPlanningState, - sort: &Sort, + producer: &mut impl SubstraitProducer, + sort: &SortExpr, schema: &DFSchemaRef, - extensions: &mut Extensions, ) -> Result { - let Sort { + let SortExpr { expr, asc, nulls_first, } = sort; - let e = to_substrait_rex(state, expr, schema, 0, extensions)?; + let e = producer.handle_expr(expr, schema)?; let d = match (asc, nulls_first) { (true, true) => SortDirection::AscNullsFirst, (true, false) => SortDirection::AscNullsLast, @@ -2380,9 +2674,9 @@ mod test { fn round_trip_literal(scalar: ScalarValue) -> Result<()> { println!("Checking round trip of {scalar:?}"); - - let mut extensions = Extensions::default(); - let substrait_literal = to_substrait_literal(&scalar, &mut extensions)?; + let state = SessionContext::default().state(); + let mut producer = DefaultSubstraitProducer::new(&state); + let substrait_literal = to_substrait_literal(&mut producer, &scalar)?; let roundtrip_scalar = from_substrait_literal_without_names(&test_consumer(), &substrait_literal)?; assert_eq!(scalar, roundtrip_scalar); diff --git a/datafusion/substrait/src/logical_plan/state.rs b/datafusion/substrait/src/logical_plan/state.rs deleted file mode 100644 index 0bd749c1105d..000000000000 --- a/datafusion/substrait/src/logical_plan/state.rs +++ /dev/null @@ -1,63 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use std::sync::Arc; - -use async_trait::async_trait; -use datafusion::{ - catalog::TableProvider, - error::{DataFusionError, Result}, - execution::{registry::SerializerRegistry, FunctionRegistry, SessionState}, - sql::TableReference, -}; - -/// This trait provides the context needed to transform a substrait plan into a -/// [`datafusion::logical_expr::LogicalPlan`] (via [`super::consumer::from_substrait_plan`]) -/// and back again into a substrait plan (via [`super::producer::to_substrait_plan`]). -/// -/// The context is declared as a trait to decouple the substrait plan encoder / -/// decoder from the [`SessionState`], potentially allowing users to define -/// their own slimmer context just for serializing and deserializing substrait. -/// -/// [`SessionState`] implements this trait. -#[async_trait] -pub trait SubstraitPlanningState: Sync + Send + FunctionRegistry { - /// Return [SerializerRegistry] for extensions - fn serializer_registry(&self) -> &Arc; - - async fn table( - &self, - reference: &TableReference, - ) -> Result>>; -} - -#[async_trait] -impl SubstraitPlanningState for SessionState { - fn serializer_registry(&self) -> &Arc { - self.serializer_registry() - } - - async fn table( - &self, - reference: &TableReference, - ) -> Result>, DataFusionError> { - let table = reference.table().to_string(); - let schema = self.schema_for_ref(reference.clone())?; - let table_provider = schema.table(&table).await?; - Ok(table_provider) - } -} diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index 383fe44be507..7045729493b1 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -571,6 +571,21 @@ async fn roundtrip_self_implicit_cross_join() -> Result<()> { roundtrip("SELECT left.a left_a, left.b, right.a right_a, right.c FROM data AS left, data AS right").await } +#[tokio::test] +async fn self_join_introduces_aliases() -> Result<()> { + assert_expected_plan( + "SELECT d1.b, d2.c FROM data d1 JOIN data d2 ON d1.b = d2.b", + "Projection: left.b, right.c\ + \n Inner Join: left.b = right.b\ + \n SubqueryAlias: left\ + \n TableScan: data projection=[b]\ + \n SubqueryAlias: right\ + \n TableScan: data projection=[b, c]", + false, + ) + .await +} + #[tokio::test] async fn roundtrip_arithmetic_ops() -> Result<()> { roundtrip("SELECT a - a FROM data").await?;