diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index 01a854ffbdf2..c9d27237a49b 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -53,6 +53,8 @@ use crate::variation_const::{ }; use datafusion::arrow::array::{new_empty_array, AsArray}; use datafusion::common::scalar::ScalarStructBuilder; +use datafusion::dataframe::DataFrame; +use datafusion::logical_expr::builder::project; use datafusion::logical_expr::expr::InList; use datafusion::logical_expr::{ col, expr, Cast, Extension, GroupingSet, Like, LogicalPlanBuilder, Partitioning, @@ -277,6 +279,20 @@ pub fn extract_projection( ); Ok(LogicalPlan::TableScan(scan)) } + LogicalPlan::Projection(projection) => { + // create another Projection around the Projection to handle the field masking + let fields: Vec = column_indices + .into_iter() + .map(|i| { + let (qualifier, field) = + projection.schema.qualified_field(i); + let column = + Column::new(qualifier.cloned(), field.name()); + Expr::Column(column) + }) + .collect(); + project(LogicalPlan::Projection(projection), fields) + } _ => plan_err!("unexpected plan for table"), } } @@ -640,6 +656,10 @@ pub async fn from_substrait_rel( } Some(RelType::Read(read)) => match &read.as_ref().read_type { Some(ReadType::NamedTable(nt)) => { + let named_struct = read.base_schema.as_ref().ok_or_else(|| { + substrait_datafusion_err!("No base schema provided for Named Table") + })?; + let table_reference = match nt.names.len() { 0 => { return plan_err!("No table name found in NamedTable"); @@ -657,7 +677,13 @@ pub async fn from_substrait_rel( table: nt.names[2].clone().into(), }, }; - let t = ctx.table(table_reference).await?; + + let substrait_schema = + from_substrait_named_struct(named_struct, extensions)? + .replace_qualifier(table_reference.clone()); + + let t = ctx.table(table_reference.clone()).await?; + let t = ensure_schema_compatability(t, substrait_schema)?; let t = t.into_optimized_plan()?; extract_projection(t, &read.projection) } @@ -671,7 +697,7 @@ pub async fn from_substrait_rel( if vt.values.is_empty() { return Ok(LogicalPlan::EmptyRelation(EmptyRelation { produce_one_row: false, - schema, + schema: DFSchemaRef::new(schema), })); } @@ -704,7 +730,10 @@ pub async fn from_substrait_rel( }) .collect::>()?; - Ok(LogicalPlan::Values(Values { schema, values })) + Ok(LogicalPlan::Values(Values { + schema: DFSchemaRef::new(schema), + values, + })) } Some(ReadType::LocalFiles(lf)) => { fn extract_filename(name: &str) -> Option { @@ -850,6 +879,87 @@ pub async fn from_substrait_rel( } } +/// Ensures that the given Substrait schema is compatible with the schema as given by DataFusion +/// +/// This means: +/// 1. All fields present in the Substrait schema are present in the DataFusion schema. The +/// DataFusion schema may have MORE fields, but not the other way around. +/// 2. All fields are compatible. See [`ensure_field_compatability`] for details +/// +/// This function returns a DataFrame with fields adjusted if necessary in the event that the +/// Substrait schema is a subset of the DataFusion schema. 
+fn ensure_schema_compatability( + table: DataFrame, + substrait_schema: DFSchema, +) -> Result<DataFrame> { + let df_schema = table.schema().to_owned().strip_qualifiers(); + if df_schema.logically_equivalent_names_and_types(&substrait_schema) { + return Ok(table); + } + let selected_columns = substrait_schema + .strip_qualifiers() + .fields() + .iter() + .map(|substrait_field| { + let df_field = + df_schema.field_with_unqualified_name(substrait_field.name())?; + ensure_field_compatability(df_field, substrait_field)?; + Ok(col(format!("\"{}\"", df_field.name()))) + }) + .collect::<Result<Vec<_>>>()?; + + table.select(selected_columns) +} + +/// Ensures that the given Substrait field is compatible with the given DataFusion field +/// +/// A field is compatible between Substrait and DataFusion if: +/// 1. They have logically equivalent types. +/// 2. They have the same nullability OR the Substrait field is nullable and the DataFusion field +/// is not nullable. +/// +/// If a Substrait field is not nullable, the Substrait plan may be built around assuming it is not +/// nullable. As such if DataFusion has that field as nullable the plan should be rejected. +fn ensure_field_compatability( + datafusion_field: &Field, + substrait_field: &Field, +) -> Result<()> { + if !DFSchema::datatype_is_logically_equal( + datafusion_field.data_type(), + substrait_field.data_type(), + ) { + return substrait_err!( + "Field '{}' in Substrait schema has a different type ({}) than the corresponding field in the table schema ({}).", + substrait_field.name(), + substrait_field.data_type(), + datafusion_field.data_type() + ); + } + + if !compatible_nullabilities( + datafusion_field.is_nullable(), + substrait_field.is_nullable(), + ) { + // TODO: from_substrait_struct_type needs to be updated to set the nullability correctly. It defaults to true for now. + return substrait_err!( + "Field '{}' is nullable in the DataFusion schema but not nullable in the Substrait schema.", + substrait_field.name() + ); + } + Ok(()) +} + +/// Returns true if the DataFusion and Substrait nullabilities are compatible, false otherwise +fn compatible_nullabilities( + datafusion_nullability: bool, + substrait_nullability: bool, +) -> bool { + // DataFusion and Substrait have the same nullability + (datafusion_nullability == substrait_nullability) + // DataFusion is not nullable and Substrait is nullable + || (!datafusion_nullability && substrait_nullability) +} + /// (Re)qualify the sides of a join if needed, i.e. if the columns from one side would otherwise /// conflict with the columns from the other. /// Substrait doesn't currently allow specifying aliases, neither for columns nor for tables. 
For @@ -1588,10 +1698,11 @@ fn next_struct_field_name( } } -fn from_substrait_named_struct( +/// Convert Substrait NamedStruct to DataFusion DFSchema +pub fn from_substrait_named_struct( base_schema: &NamedStruct, extensions: &Extensions, -) -> Result<DFSchemaRef> { +) -> Result<DFSchema> { let mut name_idx = 0; let fields = from_substrait_struct_type( base_schema.r#struct.as_ref().ok_or_else(|| { @@ -1603,12 +1714,12 @@ fn next_struct_field_name( ); if name_idx != base_schema.names.len() { return substrait_err!( - "Names list must match exactly to nested schema, but found {} uses for {} names", - name_idx, - base_schema.names.len() - ); + "Names list must match exactly to nested schema, but found {} uses for {} names", + name_idx, + base_schema.names.len() + ); } - Ok(DFSchemaRef::new(DFSchema::try_from(Schema::new(fields?))?)) + DFSchema::try_from(Schema::new(fields?)) } fn from_substrait_bound( diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index f323ae146600..a923aaf31abb 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -42,8 +42,8 @@ use crate::variation_const::{ use datafusion::arrow::array::{Array, GenericListArray, OffsetSizeTrait}; use datafusion::common::{ exec_err, internal_err, not_impl_err, plan_err, substrait_datafusion_err, + substrait_err, DFSchemaRef, ToDFSchema, }; -use datafusion::common::{substrait_err, DFSchemaRef}; #[allow(unused_imports)] use datafusion::logical_expr::expr::{ Alias, BinaryExpr, Case, Cast, GroupingSet, InList, InSubquery, Sort, WindowFunction, @@ -139,19 +139,13 @@ pub fn to_substrait_rel( maintain_singular_struct: false, }); + let table_schema = scan.source.schema().to_dfschema_ref()?; + let base_schema = to_substrait_named_struct(&table_schema, extensions)?; + Ok(Box::new(Rel { rel_type: Some(RelType::Read(Box::new(ReadRel { common: None, - base_schema: Some(NamedStruct { - names: scan - .source - .schema() - .fields() - .iter() - .map(|f| f.name().to_owned()) - .collect(), - r#struct: None, - }), + base_schema: Some(base_schema), filter: None, best_effort_filter: None, projection, diff --git a/datafusion/substrait/tests/cases/consumer_integration.rs b/datafusion/substrait/tests/cases/consumer_integration.rs index 0a86d27e013c..dad24559a06f 100644 --- a/datafusion/substrait/tests/cases/consumer_integration.rs +++ b/datafusion/substrait/tests/cases/consumer_integration.rs @@ -92,22 +92,22 @@ mod tests { let plan_str = format!("{}", plan); assert_eq!( plan_str, - "Projection: FILENAME_PLACEHOLDER_1.s_acctbal AS S_ACCTBAL, FILENAME_PLACEHOLDER_1.s_name AS S_NAME, FILENAME_PLACEHOLDER_3.n_name AS N_NAME, FILENAME_PLACEHOLDER_0.p_partkey AS P_PARTKEY, FILENAME_PLACEHOLDER_0.p_mfgr AS P_MFGR, FILENAME_PLACEHOLDER_1.s_address AS S_ADDRESS, FILENAME_PLACEHOLDER_1.s_phone AS S_PHONE, FILENAME_PLACEHOLDER_1.s_comment AS S_COMMENT\ + "Projection: FILENAME_PLACEHOLDER_1.s_acctbal AS S_ACCTBAL, FILENAME_PLACEHOLDER_1.s_name AS S_NAME, FILENAME_PLACEHOLDER_3.N_NAME, FILENAME_PLACEHOLDER_0.p_partkey AS P_PARTKEY, FILENAME_PLACEHOLDER_0.p_mfgr AS P_MFGR, FILENAME_PLACEHOLDER_1.s_address AS S_ADDRESS, FILENAME_PLACEHOLDER_1.s_phone AS S_PHONE, FILENAME_PLACEHOLDER_1.s_comment AS S_COMMENT\ \n Limit: skip=0, fetch=100\ - \n Sort: FILENAME_PLACEHOLDER_1.s_acctbal DESC NULLS FIRST, FILENAME_PLACEHOLDER_3.n_name ASC NULLS LAST, FILENAME_PLACEHOLDER_1.s_name ASC NULLS LAST, FILENAME_PLACEHOLDER_0.p_partkey ASC NULLS LAST\ - \n Projection: 
FILENAME_PLACEHOLDER_1.s_acctbal, FILENAME_PLACEHOLDER_1.s_name, FILENAME_PLACEHOLDER_3.n_name, FILENAME_PLACEHOLDER_0.p_partkey, FILENAME_PLACEHOLDER_0.p_mfgr, FILENAME_PLACEHOLDER_1.s_address, FILENAME_PLACEHOLDER_1.s_phone, FILENAME_PLACEHOLDER_1.s_comment\ - \n Filter: FILENAME_PLACEHOLDER_0.p_partkey = FILENAME_PLACEHOLDER_2.ps_partkey AND FILENAME_PLACEHOLDER_1.s_suppkey = FILENAME_PLACEHOLDER_2.ps_suppkey AND FILENAME_PLACEHOLDER_0.p_size = Int32(15) AND FILENAME_PLACEHOLDER_0.p_type LIKE CAST(Utf8(\"%BRASS\") AS Utf8) AND FILENAME_PLACEHOLDER_1.s_nationkey = FILENAME_PLACEHOLDER_3.n_nationkey AND FILENAME_PLACEHOLDER_3.n_regionkey = FILENAME_PLACEHOLDER_4.r_regionkey AND FILENAME_PLACEHOLDER_4.r_name = CAST(Utf8(\"EUROPE\") AS Utf8) AND FILENAME_PLACEHOLDER_2.ps_supplycost = ()\ + \n Sort: FILENAME_PLACEHOLDER_1.s_acctbal DESC NULLS FIRST, FILENAME_PLACEHOLDER_3.N_NAME ASC NULLS LAST, FILENAME_PLACEHOLDER_1.s_name ASC NULLS LAST, FILENAME_PLACEHOLDER_0.p_partkey ASC NULLS LAST\ + \n Projection: FILENAME_PLACEHOLDER_1.s_acctbal, FILENAME_PLACEHOLDER_1.s_name, FILENAME_PLACEHOLDER_3.N_NAME, FILENAME_PLACEHOLDER_0.p_partkey, FILENAME_PLACEHOLDER_0.p_mfgr, FILENAME_PLACEHOLDER_1.s_address, FILENAME_PLACEHOLDER_1.s_phone, FILENAME_PLACEHOLDER_1.s_comment\ + \n Filter: FILENAME_PLACEHOLDER_0.p_partkey = FILENAME_PLACEHOLDER_2.ps_partkey AND FILENAME_PLACEHOLDER_1.s_suppkey = FILENAME_PLACEHOLDER_2.ps_suppkey AND FILENAME_PLACEHOLDER_0.p_size = Int32(15) AND FILENAME_PLACEHOLDER_0.p_type LIKE CAST(Utf8(\"%BRASS\") AS Utf8) AND FILENAME_PLACEHOLDER_1.s_nationkey = FILENAME_PLACEHOLDER_3.N_NATIONKEY AND FILENAME_PLACEHOLDER_3.N_REGIONKEY = FILENAME_PLACEHOLDER_4.R_REGIONKEY AND FILENAME_PLACEHOLDER_4.R_NAME = CAST(Utf8(\"EUROPE\") AS Utf8) AND FILENAME_PLACEHOLDER_2.ps_supplycost = ()\ \n Subquery:\ \n Aggregate: groupBy=[[]], aggr=[[min(FILENAME_PLACEHOLDER_5.ps_supplycost)]]\ \n Projection: FILENAME_PLACEHOLDER_5.ps_supplycost\ - \n Filter: FILENAME_PLACEHOLDER_5.ps_partkey = FILENAME_PLACEHOLDER_5.ps_partkey AND FILENAME_PLACEHOLDER_6.s_suppkey = FILENAME_PLACEHOLDER_5.ps_suppkey AND FILENAME_PLACEHOLDER_6.s_nationkey = FILENAME_PLACEHOLDER_7.n_nationkey AND FILENAME_PLACEHOLDER_7.n_regionkey = FILENAME_PLACEHOLDER_8.r_regionkey AND FILENAME_PLACEHOLDER_8.r_name = CAST(Utf8(\"EUROPE\") AS Utf8)\ + \n Filter: FILENAME_PLACEHOLDER_5.ps_partkey = FILENAME_PLACEHOLDER_5.ps_partkey AND FILENAME_PLACEHOLDER_6.s_suppkey = FILENAME_PLACEHOLDER_5.ps_suppkey AND FILENAME_PLACEHOLDER_6.s_nationkey = FILENAME_PLACEHOLDER_7.N_NATIONKEY AND FILENAME_PLACEHOLDER_7.N_REGIONKEY = FILENAME_PLACEHOLDER_8.R_REGIONKEY AND FILENAME_PLACEHOLDER_8.R_NAME = CAST(Utf8(\"EUROPE\") AS Utf8)\ \n Inner Join: Filter: Boolean(true)\ \n Inner Join: Filter: Boolean(true)\ \n Inner Join: Filter: Boolean(true)\ \n TableScan: FILENAME_PLACEHOLDER_5 projection=[ps_partkey, ps_suppkey, ps_availqty, ps_supplycost, ps_comment]\ \n TableScan: FILENAME_PLACEHOLDER_6 projection=[s_suppkey, s_name, s_address, s_nationkey, s_phone, s_acctbal, s_comment]\ - \n TableScan: FILENAME_PLACEHOLDER_7 projection=[n_nationkey, n_name, n_regionkey, n_comment]\ - \n TableScan: FILENAME_PLACEHOLDER_8 projection=[r_regionkey, r_name, r_comment]\ + \n TableScan: FILENAME_PLACEHOLDER_7 projection=[N_NATIONKEY, N_NAME, N_REGIONKEY, N_COMMENT]\ + \n TableScan: FILENAME_PLACEHOLDER_8 projection=[R_REGIONKEY, R_NAME, R_COMMENT]\ \n Inner Join: Filter: Boolean(true)\ \n Inner Join: Filter: Boolean(true)\ \n Inner Join: Filter: Boolean(true)\ @@ -115,8 
+115,8 @@ mod tests { \n TableScan: FILENAME_PLACEHOLDER_0 projection=[p_partkey, p_name, p_mfgr, p_brand, p_type, p_size, p_container, p_retailprice, p_comment]\ \n TableScan: FILENAME_PLACEHOLDER_1 projection=[s_suppkey, s_name, s_address, s_nationkey, s_phone, s_acctbal, s_comment]\ \n TableScan: FILENAME_PLACEHOLDER_2 projection=[ps_partkey, ps_suppkey, ps_availqty, ps_supplycost, ps_comment]\ - \n TableScan: FILENAME_PLACEHOLDER_3 projection=[n_nationkey, n_name, n_regionkey, n_comment]\ - \n TableScan: FILENAME_PLACEHOLDER_4 projection=[r_regionkey, r_name, r_comment]" + \n TableScan: FILENAME_PLACEHOLDER_3 projection=[N_NATIONKEY, N_NAME, N_REGIONKEY, N_COMMENT]\ + \n TableScan: FILENAME_PLACEHOLDER_4 projection=[R_REGIONKEY, R_NAME, R_COMMENT]" ); Ok(()) } @@ -196,11 +196,11 @@ mod tests { let plan = from_substrait_plan(&ctx, &proto).await?; let plan_str = format!("{}", plan); - assert_eq!(plan_str, "Projection: NATION.n_name AS N_NAME, sum(FILENAME_PLACEHOLDER_2.l_extendedprice * Int32(1) - FILENAME_PLACEHOLDER_2.l_discount) AS REVENUE\ + assert_eq!(plan_str, "Projection: NATION.N_NAME, sum(FILENAME_PLACEHOLDER_2.l_extendedprice * Int32(1) - FILENAME_PLACEHOLDER_2.l_discount) AS REVENUE\ \n Sort: sum(FILENAME_PLACEHOLDER_2.l_extendedprice * Int32(1) - FILENAME_PLACEHOLDER_2.l_discount) DESC NULLS FIRST\ - \n Aggregate: groupBy=[[NATION.n_name]], aggr=[[sum(FILENAME_PLACEHOLDER_2.l_extendedprice * Int32(1) - FILENAME_PLACEHOLDER_2.l_discount)]]\ - \n Projection: NATION.n_name, FILENAME_PLACEHOLDER_2.l_extendedprice * (CAST(Int32(1) AS Decimal128(19, 0)) - FILENAME_PLACEHOLDER_2.l_discount)\ - \n Filter: FILENAME_PLACEHOLDER_0.c_custkey = FILENAME_PLACEHOLDER_1.o_custkey AND FILENAME_PLACEHOLDER_2.l_orderkey = FILENAME_PLACEHOLDER_1.o_orderkey AND FILENAME_PLACEHOLDER_2.l_suppkey = FILENAME_PLACEHOLDER_3.s_suppkey AND FILENAME_PLACEHOLDER_0.c_nationkey = FILENAME_PLACEHOLDER_3.s_nationkey AND FILENAME_PLACEHOLDER_3.s_nationkey = NATION.n_nationkey AND NATION.n_regionkey = REGION.r_regionkey AND REGION.r_name = CAST(Utf8(\"ASIA\") AS Utf8) AND FILENAME_PLACEHOLDER_1.o_orderdate >= CAST(Utf8(\"1994-01-01\") AS Date32) AND FILENAME_PLACEHOLDER_1.o_orderdate < CAST(Utf8(\"1995-01-01\") AS Date32)\ + \n Aggregate: groupBy=[[NATION.N_NAME]], aggr=[[sum(FILENAME_PLACEHOLDER_2.l_extendedprice * Int32(1) - FILENAME_PLACEHOLDER_2.l_discount)]]\ + \n Projection: NATION.N_NAME, FILENAME_PLACEHOLDER_2.l_extendedprice * (CAST(Int32(1) AS Decimal128(19, 0)) - FILENAME_PLACEHOLDER_2.l_discount)\ + \n Filter: FILENAME_PLACEHOLDER_0.c_custkey = FILENAME_PLACEHOLDER_1.o_custkey AND FILENAME_PLACEHOLDER_2.l_orderkey = FILENAME_PLACEHOLDER_1.o_orderkey AND FILENAME_PLACEHOLDER_2.l_suppkey = FILENAME_PLACEHOLDER_3.s_suppkey AND FILENAME_PLACEHOLDER_0.c_nationkey = FILENAME_PLACEHOLDER_3.s_nationkey AND FILENAME_PLACEHOLDER_3.s_nationkey = NATION.N_NATIONKEY AND NATION.N_REGIONKEY = REGION.R_REGIONKEY AND REGION.R_NAME = CAST(Utf8(\"ASIA\") AS Utf8) AND FILENAME_PLACEHOLDER_1.o_orderdate >= CAST(Utf8(\"1994-01-01\") AS Date32) AND FILENAME_PLACEHOLDER_1.o_orderdate < CAST(Utf8(\"1995-01-01\") AS Date32)\ \n Inner Join: Filter: Boolean(true)\ \n Inner Join: Filter: Boolean(true)\ \n Inner Join: Filter: Boolean(true)\ @@ -210,8 +210,8 @@ mod tests { \n TableScan: FILENAME_PLACEHOLDER_1 projection=[o_orderkey, o_custkey, o_orderstatus, o_totalprice, o_orderdate, o_orderpriority, o_clerk, o_shippriority, o_comment]\ \n TableScan: FILENAME_PLACEHOLDER_2 projection=[l_orderkey, l_partkey, l_suppkey, 
l_linenumber, l_quantity, l_extendedprice, l_discount, l_tax, l_returnflag, l_linestatus, l_shipdate, l_commitdate, l_receiptdate, l_shipinstruct, l_shipmode, l_comment]\ \n TableScan: FILENAME_PLACEHOLDER_3 projection=[s_suppkey, s_name, s_address, s_nationkey, s_phone, s_acctbal, s_comment]\ - \n TableScan: NATION projection=[n_nationkey, n_name, n_regionkey, n_comment]\ - \n TableScan: REGION projection=[r_regionkey, r_name, r_comment]"); + \n TableScan: NATION projection=[N_NATIONKEY, N_NAME, N_REGIONKEY, N_COMMENT]\ + \n TableScan: REGION projection=[R_REGIONKEY, R_NAME, R_COMMENT]"); Ok(()) } @@ -255,19 +255,19 @@ mod tests { let plan = from_substrait_plan(&ctx, &proto).await?; let plan_str = format!("{}", plan); - assert_eq!(plan_str, "Projection: FILENAME_PLACEHOLDER_0.c_custkey AS C_CUSTKEY, FILENAME_PLACEHOLDER_0.c_name AS C_NAME, sum(FILENAME_PLACEHOLDER_2.l_extendedprice * Int32(1) - FILENAME_PLACEHOLDER_2.l_discount) AS REVENUE, FILENAME_PLACEHOLDER_0.c_acctbal AS C_ACCTBAL, FILENAME_PLACEHOLDER_3.n_name AS N_NAME, FILENAME_PLACEHOLDER_0.c_address AS C_ADDRESS, FILENAME_PLACEHOLDER_0.c_phone AS C_PHONE, FILENAME_PLACEHOLDER_0.c_comment AS C_COMMENT\ + assert_eq!(plan_str, "Projection: FILENAME_PLACEHOLDER_0.c_custkey AS C_CUSTKEY, FILENAME_PLACEHOLDER_0.c_name AS C_NAME, sum(FILENAME_PLACEHOLDER_2.l_extendedprice * Int32(1) - FILENAME_PLACEHOLDER_2.l_discount) AS REVENUE, FILENAME_PLACEHOLDER_0.c_acctbal AS C_ACCTBAL, FILENAME_PLACEHOLDER_3.N_NAME, FILENAME_PLACEHOLDER_0.c_address AS C_ADDRESS, FILENAME_PLACEHOLDER_0.c_phone AS C_PHONE, FILENAME_PLACEHOLDER_0.c_comment AS C_COMMENT\ \n Limit: skip=0, fetch=20\ \n Sort: sum(FILENAME_PLACEHOLDER_2.l_extendedprice * Int32(1) - FILENAME_PLACEHOLDER_2.l_discount) DESC NULLS FIRST\ - \n Projection: FILENAME_PLACEHOLDER_0.c_custkey, FILENAME_PLACEHOLDER_0.c_name, sum(FILENAME_PLACEHOLDER_2.l_extendedprice * Int32(1) - FILENAME_PLACEHOLDER_2.l_discount), FILENAME_PLACEHOLDER_0.c_acctbal, FILENAME_PLACEHOLDER_3.n_name, FILENAME_PLACEHOLDER_0.c_address, FILENAME_PLACEHOLDER_0.c_phone, FILENAME_PLACEHOLDER_0.c_comment\n Aggregate: groupBy=[[FILENAME_PLACEHOLDER_0.c_custkey, FILENAME_PLACEHOLDER_0.c_name, FILENAME_PLACEHOLDER_0.c_acctbal, FILENAME_PLACEHOLDER_0.c_phone, FILENAME_PLACEHOLDER_3.n_name, FILENAME_PLACEHOLDER_0.c_address, FILENAME_PLACEHOLDER_0.c_comment]], aggr=[[sum(FILENAME_PLACEHOLDER_2.l_extendedprice * Int32(1) - FILENAME_PLACEHOLDER_2.l_discount)]]\ - \n Projection: FILENAME_PLACEHOLDER_0.c_custkey, FILENAME_PLACEHOLDER_0.c_name, FILENAME_PLACEHOLDER_0.c_acctbal, FILENAME_PLACEHOLDER_0.c_phone, FILENAME_PLACEHOLDER_3.n_name, FILENAME_PLACEHOLDER_0.c_address, FILENAME_PLACEHOLDER_0.c_comment, FILENAME_PLACEHOLDER_2.l_extendedprice * (CAST(Int32(1) AS Decimal128(19, 0)) - FILENAME_PLACEHOLDER_2.l_discount)\ - \n Filter: FILENAME_PLACEHOLDER_0.c_custkey = FILENAME_PLACEHOLDER_1.o_custkey AND FILENAME_PLACEHOLDER_2.l_orderkey = FILENAME_PLACEHOLDER_1.o_orderkey AND FILENAME_PLACEHOLDER_1.o_orderdate >= CAST(Utf8(\"1993-10-01\") AS Date32) AND FILENAME_PLACEHOLDER_1.o_orderdate < CAST(Utf8(\"1994-01-01\") AS Date32) AND FILENAME_PLACEHOLDER_2.l_returnflag = Utf8(\"R\") AND FILENAME_PLACEHOLDER_0.c_nationkey = FILENAME_PLACEHOLDER_3.n_nationkey\ + \n Projection: FILENAME_PLACEHOLDER_0.c_custkey, FILENAME_PLACEHOLDER_0.c_name, sum(FILENAME_PLACEHOLDER_2.l_extendedprice * Int32(1) - FILENAME_PLACEHOLDER_2.l_discount), FILENAME_PLACEHOLDER_0.c_acctbal, FILENAME_PLACEHOLDER_3.N_NAME, FILENAME_PLACEHOLDER_0.c_address, 
FILENAME_PLACEHOLDER_0.c_phone, FILENAME_PLACEHOLDER_0.c_comment\n Aggregate: groupBy=[[FILENAME_PLACEHOLDER_0.c_custkey, FILENAME_PLACEHOLDER_0.c_name, FILENAME_PLACEHOLDER_0.c_acctbal, FILENAME_PLACEHOLDER_0.c_phone, FILENAME_PLACEHOLDER_3.N_NAME, FILENAME_PLACEHOLDER_0.c_address, FILENAME_PLACEHOLDER_0.c_comment]], aggr=[[sum(FILENAME_PLACEHOLDER_2.l_extendedprice * Int32(1) - FILENAME_PLACEHOLDER_2.l_discount)]]\ + \n Projection: FILENAME_PLACEHOLDER_0.c_custkey, FILENAME_PLACEHOLDER_0.c_name, FILENAME_PLACEHOLDER_0.c_acctbal, FILENAME_PLACEHOLDER_0.c_phone, FILENAME_PLACEHOLDER_3.N_NAME, FILENAME_PLACEHOLDER_0.c_address, FILENAME_PLACEHOLDER_0.c_comment, FILENAME_PLACEHOLDER_2.l_extendedprice * (CAST(Int32(1) AS Decimal128(19, 0)) - FILENAME_PLACEHOLDER_2.l_discount)\ + \n Filter: FILENAME_PLACEHOLDER_0.c_custkey = FILENAME_PLACEHOLDER_1.o_custkey AND FILENAME_PLACEHOLDER_2.l_orderkey = FILENAME_PLACEHOLDER_1.o_orderkey AND FILENAME_PLACEHOLDER_1.o_orderdate >= CAST(Utf8(\"1993-10-01\") AS Date32) AND FILENAME_PLACEHOLDER_1.o_orderdate < CAST(Utf8(\"1994-01-01\") AS Date32) AND FILENAME_PLACEHOLDER_2.l_returnflag = Utf8(\"R\") AND FILENAME_PLACEHOLDER_0.c_nationkey = FILENAME_PLACEHOLDER_3.N_NATIONKEY\ \n Inner Join: Filter: Boolean(true)\ \n Inner Join: Filter: Boolean(true)\ \n Inner Join: Filter: Boolean(true)\ \n TableScan: FILENAME_PLACEHOLDER_0 projection=[c_custkey, c_name, c_address, c_nationkey, c_phone, c_acctbal, c_mktsegment, c_comment]\ \n TableScan: FILENAME_PLACEHOLDER_1 projection=[o_orderkey, o_custkey, o_orderstatus, o_totalprice, o_orderdate, o_orderpriority, o_clerk, o_shippriority, o_comment]\ \n TableScan: FILENAME_PLACEHOLDER_2 projection=[l_orderkey, l_partkey, l_suppkey, l_linenumber, l_quantity, l_extendedprice, l_discount, l_tax, l_returnflag, l_linestatus, l_shipdate, l_commitdate, l_receiptdate, l_shipinstruct, l_shipmode, l_comment]\ - \n TableScan: FILENAME_PLACEHOLDER_3 projection=[n_nationkey, n_name, n_regionkey, n_comment]"); + \n TableScan: FILENAME_PLACEHOLDER_3 projection=[N_NATIONKEY, N_NAME, N_REGIONKEY, N_COMMENT]"); Ok(()) } @@ -297,20 +297,20 @@ mod tests { \n Projection: sum(FILENAME_PLACEHOLDER_3.ps_supplycost * FILENAME_PLACEHOLDER_3.ps_availqty) * Decimal128(Some(1000000),11,10)\ \n Aggregate: groupBy=[[]], aggr=[[sum(FILENAME_PLACEHOLDER_3.ps_supplycost * FILENAME_PLACEHOLDER_3.ps_availqty)]]\ \n Projection: FILENAME_PLACEHOLDER_3.ps_supplycost * CAST(FILENAME_PLACEHOLDER_3.ps_availqty AS Decimal128(19, 0))\ - \n Filter: FILENAME_PLACEHOLDER_3.ps_suppkey = FILENAME_PLACEHOLDER_4.s_suppkey AND FILENAME_PLACEHOLDER_4.s_nationkey = FILENAME_PLACEHOLDER_5.n_nationkey AND FILENAME_PLACEHOLDER_5.n_name = CAST(Utf8(\"JAPAN\") AS Utf8)\ + \n Filter: FILENAME_PLACEHOLDER_3.ps_suppkey = FILENAME_PLACEHOLDER_4.s_suppkey AND FILENAME_PLACEHOLDER_4.s_nationkey = FILENAME_PLACEHOLDER_5.N_NATIONKEY AND FILENAME_PLACEHOLDER_5.N_NAME = CAST(Utf8(\"JAPAN\") AS Utf8)\ \n Inner Join: Filter: Boolean(true)\ \n Inner Join: Filter: Boolean(true)\ \n TableScan: FILENAME_PLACEHOLDER_3 projection=[ps_partkey, ps_suppkey, ps_availqty, ps_supplycost, ps_comment]\ \n TableScan: FILENAME_PLACEHOLDER_4 projection=[s_suppkey, s_name, s_address, s_nationkey, s_phone, s_acctbal, s_comment]\ - \n TableScan: FILENAME_PLACEHOLDER_5 projection=[n_nationkey, n_name, n_regionkey, n_comment]\ + \n TableScan: FILENAME_PLACEHOLDER_5 projection=[N_NATIONKEY, N_NAME, N_REGIONKEY, N_COMMENT]\ \n Aggregate: groupBy=[[FILENAME_PLACEHOLDER_0.ps_partkey]], 
aggr=[[sum(FILENAME_PLACEHOLDER_0.ps_supplycost * FILENAME_PLACEHOLDER_0.ps_availqty)]]\ \n Projection: FILENAME_PLACEHOLDER_0.ps_partkey, FILENAME_PLACEHOLDER_0.ps_supplycost * CAST(FILENAME_PLACEHOLDER_0.ps_availqty AS Decimal128(19, 0))\ - \n Filter: FILENAME_PLACEHOLDER_0.ps_suppkey = FILENAME_PLACEHOLDER_1.s_suppkey AND FILENAME_PLACEHOLDER_1.s_nationkey = FILENAME_PLACEHOLDER_2.n_nationkey AND FILENAME_PLACEHOLDER_2.n_name = CAST(Utf8(\"JAPAN\") AS Utf8)\ + \n Filter: FILENAME_PLACEHOLDER_0.ps_suppkey = FILENAME_PLACEHOLDER_1.s_suppkey AND FILENAME_PLACEHOLDER_1.s_nationkey = FILENAME_PLACEHOLDER_2.N_NATIONKEY AND FILENAME_PLACEHOLDER_2.N_NAME = CAST(Utf8(\"JAPAN\") AS Utf8)\ \n Inner Join: Filter: Boolean(true)\ \n Inner Join: Filter: Boolean(true)\ \n TableScan: FILENAME_PLACEHOLDER_0 projection=[ps_partkey, ps_suppkey, ps_availqty, ps_supplycost, ps_comment]\ \n TableScan: FILENAME_PLACEHOLDER_1 projection=[s_suppkey, s_name, s_address, s_nationkey, s_phone, s_acctbal, s_comment]\ - \n TableScan: FILENAME_PLACEHOLDER_2 projection=[n_nationkey, n_name, n_regionkey, n_comment]"); + \n TableScan: FILENAME_PLACEHOLDER_2 projection=[N_NATIONKEY, N_NAME, N_REGIONKEY, N_COMMENT]"); Ok(()) } @@ -498,7 +498,7 @@ mod tests { assert_eq!(plan_str, "Projection: FILENAME_PLACEHOLDER_0.s_name AS S_NAME, FILENAME_PLACEHOLDER_0.s_address AS S_ADDRESS\ \n Sort: FILENAME_PLACEHOLDER_0.s_name ASC NULLS LAST\ \n Projection: FILENAME_PLACEHOLDER_0.s_name, FILENAME_PLACEHOLDER_0.s_address\ - \n Filter: CAST(FILENAME_PLACEHOLDER_0.s_suppkey IN () AS Boolean) AND FILENAME_PLACEHOLDER_0.s_nationkey = FILENAME_PLACEHOLDER_1.n_nationkey AND FILENAME_PLACEHOLDER_1.n_name = CAST(Utf8(\"CANADA\") AS Utf8)\ + \n Filter: CAST(FILENAME_PLACEHOLDER_0.s_suppkey IN () AS Boolean) AND FILENAME_PLACEHOLDER_0.s_nationkey = FILENAME_PLACEHOLDER_1.N_NATIONKEY AND FILENAME_PLACEHOLDER_1.N_NAME = CAST(Utf8(\"CANADA\") AS Utf8)\ \n Subquery:\ \n Projection: FILENAME_PLACEHOLDER_2.ps_suppkey\ \n Filter: CAST(FILENAME_PLACEHOLDER_2.ps_partkey IN () AS Boolean) AND CAST(FILENAME_PLACEHOLDER_2.ps_availqty AS Decimal128(19, 1)) > ()\ @@ -515,7 +515,7 @@ mod tests { \n TableScan: FILENAME_PLACEHOLDER_2 projection=[ps_partkey, ps_suppkey, ps_availqty, ps_supplycost, ps_comment]\ \n Inner Join: Filter: Boolean(true)\ \n TableScan: FILENAME_PLACEHOLDER_0 projection=[s_suppkey, s_name, s_address, s_nationkey, s_phone, s_acctbal, s_comment]\ - \n TableScan: FILENAME_PLACEHOLDER_1 projection=[n_nationkey, n_name, n_regionkey, n_comment]"); + \n TableScan: FILENAME_PLACEHOLDER_1 projection=[N_NATIONKEY, N_NAME, N_REGIONKEY, N_COMMENT]"); Ok(()) } @@ -543,7 +543,7 @@ mod tests { \n Sort: count(Int64(1)) DESC NULLS FIRST, FILENAME_PLACEHOLDER_0.s_name ASC NULLS LAST\ \n Aggregate: groupBy=[[FILENAME_PLACEHOLDER_0.s_name]], aggr=[[count(Int64(1))]]\ \n Projection: FILENAME_PLACEHOLDER_0.s_name\ - \n Filter: FILENAME_PLACEHOLDER_0.s_suppkey = FILENAME_PLACEHOLDER_1.l_suppkey AND FILENAME_PLACEHOLDER_2.o_orderkey = FILENAME_PLACEHOLDER_1.l_orderkey AND FILENAME_PLACEHOLDER_2.o_orderstatus = Utf8(\"F\") AND FILENAME_PLACEHOLDER_1.l_receiptdate > FILENAME_PLACEHOLDER_1.l_commitdate AND EXISTS () AND NOT EXISTS () AND FILENAME_PLACEHOLDER_0.s_nationkey = FILENAME_PLACEHOLDER_3.n_nationkey AND FILENAME_PLACEHOLDER_3.n_name = CAST(Utf8(\"SAUDI ARABIA\") AS Utf8)\ + \n Filter: FILENAME_PLACEHOLDER_0.s_suppkey = FILENAME_PLACEHOLDER_1.l_suppkey AND FILENAME_PLACEHOLDER_2.o_orderkey = FILENAME_PLACEHOLDER_1.l_orderkey AND 
FILENAME_PLACEHOLDER_2.o_orderstatus = Utf8(\"F\") AND FILENAME_PLACEHOLDER_1.l_receiptdate > FILENAME_PLACEHOLDER_1.l_commitdate AND EXISTS () AND NOT EXISTS () AND FILENAME_PLACEHOLDER_0.s_nationkey = FILENAME_PLACEHOLDER_3.N_NATIONKEY AND FILENAME_PLACEHOLDER_3.N_NAME = CAST(Utf8(\"SAUDI ARABIA\") AS Utf8)\ \n Subquery:\ \n Filter: FILENAME_PLACEHOLDER_4.l_orderkey = FILENAME_PLACEHOLDER_4.l_tax AND FILENAME_PLACEHOLDER_4.l_suppkey != FILENAME_PLACEHOLDER_4.l_linestatus\ \n TableScan: FILENAME_PLACEHOLDER_4 projection=[l_orderkey, l_partkey, l_suppkey, l_linenumber, l_quantity, l_extendedprice, l_discount, l_tax, l_returnflag, l_linestatus, l_shipdate, l_commitdate, l_receiptdate, l_shipinstruct, l_shipmode, l_comment]\ @@ -555,7 +555,7 @@ mod tests { \n Inner Join: Filter: Boolean(true)\ \n TableScan: FILENAME_PLACEHOLDER_0 projection=[s_suppkey, s_name, s_address, s_nationkey, s_phone, s_acctbal, s_comment]\ \n TableScan: FILENAME_PLACEHOLDER_1 projection=[l_orderkey, l_partkey, l_suppkey, l_linenumber, l_quantity, l_extendedprice, l_discount, l_tax, l_returnflag, l_linestatus, l_shipdate, l_commitdate, l_receiptdate, l_shipinstruct, l_shipmode, l_comment]\n TableScan: FILENAME_PLACEHOLDER_2 projection=[o_orderkey, o_custkey, o_orderstatus, o_totalprice, o_orderdate, o_orderpriority, o_clerk, o_shippriority, o_comment]\ - \n TableScan: FILENAME_PLACEHOLDER_3 projection=[n_nationkey, n_name, n_regionkey, n_comment]"); + \n TableScan: FILENAME_PLACEHOLDER_3 projection=[N_NATIONKEY, N_NAME, N_REGIONKEY, N_COMMENT]"); Ok(()) } diff --git a/datafusion/substrait/tests/cases/function_test.rs b/datafusion/substrait/tests/cases/function_test.rs index 610caf3a81df..85809da6f3e4 100644 --- a/datafusion/substrait/tests/cases/function_test.rs +++ b/datafusion/substrait/tests/cases/function_test.rs @@ -19,40 +19,26 @@ #[cfg(test)] mod tests { + use crate::utils::test::{add_plan_schemas_to_ctx, read_json}; + use datafusion::common::Result; - use datafusion::prelude::{CsvReadOptions, SessionContext}; + use datafusion::prelude::SessionContext; use datafusion_substrait::logical_plan::consumer::from_substrait_plan; - use std::fs::File; - use std::io::BufReader; - use substrait::proto::Plan; #[tokio::test] async fn contains_function_test() -> Result<()> { - let ctx = create_context().await?; - - let path = "tests/testdata/contains_plan.substrait.json"; - let proto = serde_json::from_reader::<_, Plan>(BufReader::new( - File::open(path).expect("file not found"), - )) - .expect("failed to parse json"); - - let plan = from_substrait_plan(&ctx, &proto).await?; + let proto_plan = read_json("tests/testdata/contains_plan.substrait.json"); + let ctx = add_plan_schemas_to_ctx(SessionContext::new(), &proto_plan); + let plan = from_substrait_plan(&ctx, &proto_plan).await?; let plan_str = format!("{}", plan); assert_eq!( plan_str, - "Projection: nation.b AS n_name\ - \n Filter: contains(nation.b, Utf8(\"IA\"))\ - \n TableScan: nation projection=[a, b, c, d, e, f]" + "Projection: nation.n_name\ + \n Filter: contains(nation.n_name, Utf8(\"IA\"))\ + \n TableScan: nation projection=[n_nationkey, n_name, n_regionkey, n_comment]" ); Ok(()) } - - async fn create_context() -> datafusion::common::Result { - let ctx = SessionContext::new(); - ctx.register_csv("nation", "tests/testdata/data.csv", CsvReadOptions::new()) - .await?; - Ok(ctx) - } } diff --git a/datafusion/substrait/tests/cases/logical_plans.rs b/datafusion/substrait/tests/cases/logical_plans.rs index f6a2b5036c80..8db2aa283d3c 100644 --- 
a/datafusion/substrait/tests/cases/logical_plans.rs +++ b/datafusion/substrait/tests/cases/logical_plans.rs @@ -19,13 +19,11 @@ #[cfg(test)] mod tests { + use crate::utils::test::{add_plan_schemas_to_ctx, read_json}; use datafusion::common::Result; use datafusion::dataframe::DataFrame; - use datafusion::prelude::{CsvReadOptions, SessionContext}; + use datafusion::prelude::SessionContext; use datafusion_substrait::logical_plan::consumer::from_substrait_plan; - use std::fs::File; - use std::io::BufReader; - use substrait::proto::Plan; #[tokio::test] async fn scalar_function_compound_signature() -> Result<()> { @@ -35,18 +33,17 @@ mod tests { // we don't yet produce such plans. // Once we start producing plans with compound signatures, this test can be replaced by the roundtrip tests. - let ctx = create_context().await?; - // File generated with substrait-java's Isthmus: - // ./isthmus-cli/build/graal/isthmus "select not d from data" -c "create table data (d boolean)" - let proto = read_json("tests/testdata/test_plans/select_not_bool.substrait.json"); - - let plan = from_substrait_plan(&ctx, &proto).await?; + // ./isthmus-cli/build/graal/isthmus --create "create table data (d boolean)" "select not d from data" + let proto_plan = + read_json("tests/testdata/test_plans/select_not_bool.substrait.json"); + let ctx = add_plan_schemas_to_ctx(SessionContext::new(), &proto_plan); + let plan = from_substrait_plan(&ctx, &proto_plan).await?; assert_eq!( format!("{}", plan), - "Projection: NOT DATA.a AS EXPR$0\ - \n TableScan: DATA projection=[a, b, c, d, e, f]" + "Projection: NOT DATA.D AS EXPR$0\ + \n TableScan: DATA projection=[D]" ); Ok(()) } @@ -61,19 +58,18 @@ mod tests { // we don't yet produce such plans. // Once we start producing plans with compound signatures, this test can be replaced by the roundtrip tests. 
- let ctx = create_context().await?; - // File generated with substrait-java's Isthmus: - // ./isthmus-cli/build/graal/isthmus "select sum(d) OVER (PARTITION BY part ORDER BY ord ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING) AS lead_expr from data" -c "create table data (d int, part int, ord int)" - let proto = read_json("tests/testdata/test_plans/select_window.substrait.json"); - - let plan = from_substrait_plan(&ctx, &proto).await?; + // ./isthmus-cli/build/graal/isthmus --create "create table data (d int, part int, ord int)" "select sum(d) OVER (PARTITION BY part ORDER BY ord ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING) AS lead_expr from data" + let proto_plan = + read_json("tests/testdata/test_plans/select_window.substrait.json"); + let ctx = add_plan_schemas_to_ctx(SessionContext::new(), &proto_plan); + let plan = from_substrait_plan(&ctx, &proto_plan).await?; assert_eq!( format!("{}", plan), - "Projection: sum(DATA.a) PARTITION BY [DATA.b] ORDER BY [DATA.c ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING AS LEAD_EXPR\ - \n WindowAggr: windowExpr=[[sum(DATA.a) PARTITION BY [DATA.b] ORDER BY [DATA.c ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING]]\ - \n TableScan: DATA projection=[a, b, c, d, e, f]" + "Projection: sum(DATA.D) PARTITION BY [DATA.PART] ORDER BY [DATA.ORD ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING AS LEAD_EXPR\ + \n WindowAggr: windowExpr=[[sum(DATA.D) PARTITION BY [DATA.PART] ORDER BY [DATA.ORD ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING]]\ + \n TableScan: DATA projection=[D, PART, ORD]" ); Ok(()) } @@ -83,11 +79,10 @@ mod tests { // DataFusion's Substrait consumer treats all lists as nullable, even if the Substrait plan specifies them as non-nullable. // That's because implementing the non-nullability consistently is non-trivial. // This test confirms that reading a plan with non-nullable lists works as expected. - let ctx = create_context().await?; - let proto = + let proto_plan = read_json("tests/testdata/test_plans/non_nullable_lists.substrait.json"); - - let plan = from_substrait_plan(&ctx, &proto).await?; + let ctx = add_plan_schemas_to_ctx(SessionContext::new(), &proto_plan); + let plan = from_substrait_plan(&ctx, &proto_plan).await?; assert_eq!(format!("{}", &plan), "Values: (List([1, 2]))"); @@ -96,18 +91,4 @@ mod tests { Ok(()) } - - fn read_json(path: &str) -> Plan { - serde_json::from_reader::<_, Plan>(BufReader::new( - File::open(path).expect("file not found"), - )) - .expect("failed to parse json") - } - - async fn create_context() -> datafusion::common::Result { - let ctx = SessionContext::new(); - ctx.register_csv("DATA", "tests/testdata/data.csv", CsvReadOptions::new()) - .await?; - Ok(ctx) - } } diff --git a/datafusion/substrait/tests/cases/mod.rs b/datafusion/substrait/tests/cases/mod.rs index d3ea7695e4b9..42aa23626106 100644 --- a/datafusion/substrait/tests/cases/mod.rs +++ b/datafusion/substrait/tests/cases/mod.rs @@ -21,3 +21,4 @@ mod logical_plans; mod roundtrip_logical_plan; mod roundtrip_physical_plan; mod serialize; +mod substrait_validations; diff --git a/datafusion/substrait/tests/cases/substrait_validations.rs b/datafusion/substrait/tests/cases/substrait_validations.rs new file mode 100644 index 000000000000..cb1fb67fc044 --- /dev/null +++ b/datafusion/substrait/tests/cases/substrait_validations.rs @@ -0,0 +1,151 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#[cfg(test)] +mod tests { + + // verify the schema compatability validations + mod schema_compatability { + use crate::utils::test::read_json; + use datafusion::arrow::datatypes::{DataType, Field}; + use datafusion::catalog_common::TableReference; + use datafusion::common::{DFSchema, Result}; + use datafusion::datasource::empty::EmptyTable; + use datafusion::prelude::SessionContext; + use datafusion_substrait::logical_plan::consumer::from_substrait_plan; + use std::collections::HashMap; + use std::sync::Arc; + + fn generate_context_with_table( + table_name: &str, + fields: Vec<(&str, DataType, bool)>, + ) -> Result { + let table_ref = TableReference::bare(table_name); + let fields: Vec<(Option, Arc)> = fields + .into_iter() + .map(|pair| { + let (field_name, data_type, nullable) = pair; + ( + Some(table_ref.clone()), + Arc::new(Field::new(field_name, data_type, nullable)), + ) + }) + .collect(); + + let df_schema = DFSchema::new_with_metadata(fields, HashMap::default())?; + + let ctx = SessionContext::new(); + ctx.register_table( + table_ref, + Arc::new(EmptyTable::new(df_schema.inner().clone())), + )?; + Ok(ctx) + } + + #[tokio::test] + async fn ensure_schema_match_exact() -> Result<()> { + let proto_plan = + read_json("tests/testdata/test_plans/simple_select.substrait.json"); + // this is the exact schema of the Substrait plan + let df_schema = + vec![("a", DataType::Int32, false), ("b", DataType::Int32, true)]; + + let ctx = generate_context_with_table("DATA", df_schema)?; + let plan = from_substrait_plan(&ctx, &proto_plan).await?; + + assert_eq!( + format!("{}", plan), + "Projection: DATA.a, DATA.b\ + \n TableScan: DATA projection=[a, b]" + ); + Ok(()) + } + + #[tokio::test] + async fn ensure_schema_match_subset() -> Result<()> { + let proto_plan = + read_json("tests/testdata/test_plans/simple_select.substrait.json"); + // the DataFusion schema { b, a, c } contains the Substrait schema { a, b } + let df_schema = vec![ + ("b", DataType::Int32, true), + ("a", DataType::Int32, false), + ("c", DataType::Int32, false), + ]; + let ctx = generate_context_with_table("DATA", df_schema)?; + let plan = from_substrait_plan(&ctx, &proto_plan).await?; + + assert_eq!( + format!("{}", plan), + "Projection: DATA.a, DATA.b\ + \n Projection: DATA.a, DATA.b\ + \n TableScan: DATA projection=[b, a]" + ); + Ok(()) + } + + #[tokio::test] + async fn ensure_schema_match_subset_with_mask() -> Result<()> { + let proto_plan = read_json( + "tests/testdata/test_plans/simple_select_with_mask.substrait.json", + ); + // the DataFusion schema { b, a, c, d } contains the Substrait schema { a, b, c } + let df_schema = vec![ + ("b", DataType::Int32, true), + ("a", DataType::Int32, false), + ("c", DataType::Int32, false), + ("d", DataType::Int32, false), + ]; + let ctx = generate_context_with_table("DATA", df_schema)?; + let plan 
= from_substrait_plan(&ctx, &proto_plan).await?; + + assert_eq!( + format!("{}", plan), + "Projection: DATA.a, DATA.b\ + \n Projection: DATA.a, DATA.b\ + \n Projection: DATA.a, DATA.b, DATA.c\ + \n TableScan: DATA projection=[b, a, c]" + ); + Ok(()) + } + + #[tokio::test] + async fn ensure_schema_match_not_subset() -> Result<()> { + let proto_plan = + read_json("tests/testdata/test_plans/simple_select.substrait.json"); + // the substrait plans contains a field b which is not in the schema + let df_schema = + vec![("a", DataType::Int32, false), ("c", DataType::Int32, true)]; + + let ctx = generate_context_with_table("DATA", df_schema)?; + let res = from_substrait_plan(&ctx, &proto_plan).await; + assert!(res.is_err()); + Ok(()) + } + + #[tokio::test] + async fn reject_plans_with_incompatible_field_types() -> Result<()> { + let proto_plan = + read_json("tests/testdata/test_plans/simple_select.substrait.json"); + + let ctx = + generate_context_with_table("DATA", vec![("a", DataType::Date32, true)])?; + let res = from_substrait_plan(&ctx, &proto_plan).await; + assert!(res.is_err()); + Ok(()) + } + } +} diff --git a/datafusion/substrait/tests/substrait_integration.rs b/datafusion/substrait/tests/substrait_integration.rs index 6ce41c9de71a..eedd4da373e0 100644 --- a/datafusion/substrait/tests/substrait_integration.rs +++ b/datafusion/substrait/tests/substrait_integration.rs @@ -17,3 +17,4 @@ /// Run all tests that are found in the `cases` directory mod cases; +mod utils; diff --git a/datafusion/substrait/tests/testdata/test_plans/simple_select.substrait.json b/datafusion/substrait/tests/testdata/test_plans/simple_select.substrait.json new file mode 100644 index 000000000000..aee27ef3b417 --- /dev/null +++ b/datafusion/substrait/tests/testdata/test_plans/simple_select.substrait.json @@ -0,0 +1,69 @@ +{ + "extensionUris": [], + "extensions": [], + "relations": [{ + "root": { + "input": { + "project": { + "common": { + "emit": { + "outputMapping": [2, 3] + } + }, + "input": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["a", "b"], + "struct": { + "types": [{ + "i32": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, { + "i32": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }], + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": ["DATA"] + } + } + }, + "expressions": [{ + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + }] + } + }, + "names": ["a", "b"] + } + }], + "expectedTypeUrls": [] +} \ No newline at end of file diff --git a/datafusion/substrait/tests/testdata/test_plans/simple_select_with_mask.substrait.json b/datafusion/substrait/tests/testdata/test_plans/simple_select_with_mask.substrait.json new file mode 100644 index 000000000000..774126ca3836 --- /dev/null +++ b/datafusion/substrait/tests/testdata/test_plans/simple_select_with_mask.substrait.json @@ -0,0 +1,104 @@ +{ + "extensionUris": [], + "extensions": [], + "relations": [ + { + "root": { + "input": { + "project": { + "common": { + "emit": { + "outputMapping": [ + 2, + 3 + ] + } + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "a", + "b", + "c" + ], + "struct": { + "types": [ + { + "i32": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + 
} + }, + { + "i32": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "i32": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + } + ], + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": [ + "DATA" + ] + }, + "projection": { + "select": { + "struct_items": [ + { + "field": 0 + }, + { + "field": 1 + } + ] + } + } + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": {} + } + } + ] + } + }, + "names": [ + "a", + "b" + ] + } + } + ], + "expectedTypeUrls": [] +} diff --git a/datafusion/substrait/tests/testdata/tpch/nation.csv b/datafusion/substrait/tests/testdata/tpch/nation.csv index fdf7421467d3..a88d1c0d31e7 100644 --- a/datafusion/substrait/tests/testdata/tpch/nation.csv +++ b/datafusion/substrait/tests/testdata/tpch/nation.csv @@ -1,2 +1,2 @@ -n_nationkey,n_name,n_regionkey,n_comment +N_NATIONKEY,N_NAME,N_REGIONKEY,N_COMMENT 0,ALGERIA,0, haggle. carefully final deposits detect slyly agai \ No newline at end of file diff --git a/datafusion/substrait/tests/testdata/tpch/region.csv b/datafusion/substrait/tests/testdata/tpch/region.csv index 6c3fb4524355..d29c39ab8543 100644 --- a/datafusion/substrait/tests/testdata/tpch/region.csv +++ b/datafusion/substrait/tests/testdata/tpch/region.csv @@ -1,2 +1,2 @@ -r_regionkey,r_name,r_comment +R_REGIONKEY,R_NAME,R_COMMENT 0,AFRICA,lar deposits. blithely final packages cajole. regular waters are final requests. regular accounts are according to \ No newline at end of file diff --git a/datafusion/substrait/tests/utils.rs b/datafusion/substrait/tests/utils.rs new file mode 100644 index 000000000000..685e3deec581 --- /dev/null +++ b/datafusion/substrait/tests/utils.rs @@ -0,0 +1,186 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#[cfg(test)] +pub mod test { + use datafusion::catalog_common::TableReference; + use datafusion::datasource::empty::EmptyTable; + use datafusion::datasource::TableProvider; + use datafusion::prelude::SessionContext; + use datafusion_substrait::extensions::Extensions; + use datafusion_substrait::logical_plan::consumer::from_substrait_named_struct; + use std::fs::File; + use std::io::BufReader; + use std::sync::Arc; + use substrait::proto::read_rel::{NamedTable, ReadType}; + use substrait::proto::rel::RelType; + use substrait::proto::{Plan, ReadRel, Rel}; + + pub fn read_json(path: &str) -> Plan { + serde_json::from_reader::<_, Plan>(BufReader::new( + File::open(path).expect("file not found"), + )) + .expect("failed to parse json") + } + + pub fn add_plan_schemas_to_ctx(ctx: SessionContext, plan: &Plan) -> SessionContext { + let schemas = TestSchemaCollector::collect_schemas(plan); + for (table_reference, table) in schemas { + ctx.register_table(table_reference, table) + .expect("Failed to register table"); + } + ctx + } + + pub struct TestSchemaCollector { + schemas: Vec<(TableReference, Arc)>, + } + + impl TestSchemaCollector { + fn new() -> Self { + TestSchemaCollector { + schemas: Vec::new(), + } + } + + fn collect_schemas(plan: &Plan) -> Vec<(TableReference, Arc)> { + let mut schema_collector = Self::new(); + + for plan_rel in plan.relations.iter() { + match plan_rel + .rel_type + .as_ref() + .expect("PlanRel must set rel_type") + { + substrait::proto::plan_rel::RelType::Rel(r) => { + schema_collector.collect_schemas_from_rel(r) + } + substrait::proto::plan_rel::RelType::Root(r) => schema_collector + .collect_schemas_from_rel( + r.input.as_ref().expect("RelRoot must set input"), + ), + } + } + schema_collector.schemas + } + + fn collect_named_table(&mut self, read: &ReadRel, nt: &NamedTable) { + let table_reference = match nt.names.len() { + 0 => { + panic!("No table name found in NamedTable"); + } + 1 => TableReference::Bare { + table: nt.names[0].clone().into(), + }, + 2 => TableReference::Partial { + schema: nt.names[0].clone().into(), + table: nt.names[1].clone().into(), + }, + _ => TableReference::Full { + catalog: nt.names[0].clone().into(), + schema: nt.names[1].clone().into(), + table: nt.names[2].clone().into(), + }, + }; + + let substrait_schema = read + .base_schema + .as_ref() + .expect("No base schema found for NamedTable"); + let empty_extensions = Extensions { + functions: Default::default(), + types: Default::default(), + type_variations: Default::default(), + }; + + let df_schema = + from_substrait_named_struct(substrait_schema, &empty_extensions) + .expect( + "Unable to generate DataFusion schema from Substrait NamedStruct", + ) + .replace_qualifier(table_reference.clone()); + + let table = EmptyTable::new(df_schema.inner().clone()); + self.schemas.push((table_reference, Arc::new(table))); + } + + fn collect_schemas_from_rel(&mut self, rel: &Rel) { + match rel.rel_type.as_ref().unwrap() { + RelType::Read(r) => match r.read_type.as_ref().unwrap() { + // Virtual Tables do not contribute to the schema + ReadType::VirtualTable(_) => (), + ReadType::LocalFiles(_) => todo!(), + ReadType::NamedTable(nt) => self.collect_named_table(r, nt), + ReadType::ExtensionTable(_) => todo!(), + }, + RelType::Filter(f) => self.apply(f.input.as_ref().map(|b| b.as_ref())), + RelType::Fetch(f) => self.apply(f.input.as_ref().map(|b| b.as_ref())), + RelType::Aggregate(a) => self.apply(a.input.as_ref().map(|b| b.as_ref())), + RelType::Sort(s) => self.apply(s.input.as_ref().map(|b| 
b.as_ref())), + RelType::Join(j) => { + self.apply(j.left.as_ref().map(|b| b.as_ref())); + self.apply(j.right.as_ref().map(|b| b.as_ref())); + } + RelType::Project(p) => self.apply(p.input.as_ref().map(|b| b.as_ref())), + RelType::Set(s) => { + for input in s.inputs.iter() { + self.collect_schemas_from_rel(input); + } + } + RelType::ExtensionSingle(s) => { + self.apply(s.input.as_ref().map(|b| b.as_ref())) + } + RelType::ExtensionMulti(m) => { + for input in m.inputs.iter() { + self.collect_schemas_from_rel(input) + } + } + RelType::ExtensionLeaf(_) => {} + RelType::Cross(c) => { + self.apply(c.left.as_ref().map(|b| b.as_ref())); + self.apply(c.right.as_ref().map(|b| b.as_ref())); + } + // RelType::Reference(_) => {} + // RelType::Write(_) => {} + // RelType::Ddl(_) => {} + RelType::HashJoin(j) => { + self.apply(j.left.as_ref().map(|b| b.as_ref())); + self.apply(j.right.as_ref().map(|b| b.as_ref())); + } + RelType::MergeJoin(j) => { + self.apply(j.left.as_ref().map(|b| b.as_ref())); + self.apply(j.right.as_ref().map(|b| b.as_ref())); + } + RelType::NestedLoopJoin(j) => { + self.apply(j.left.as_ref().map(|b| b.as_ref())); + self.apply(j.right.as_ref().map(|b| b.as_ref())); + } + RelType::Window(w) => self.apply(w.input.as_ref().map(|b| b.as_ref())), + RelType::Exchange(e) => self.apply(e.input.as_ref().map(|b| b.as_ref())), + RelType::Expand(e) => self.apply(e.input.as_ref().map(|b| b.as_ref())), + _ => todo!(), + } + } + + fn apply(&mut self, input: Option<&Rel>) { + match input { + None => {} + Some(rel) => self.collect_schemas_from_rel(rel), + } + } + } +}
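For reference, the new test helpers in utils.rs are meant to be used the way the updated tests above use them: parse a Substrait plan from JSON, register an EmptyTable for every NamedTable the plan references (using the plan's own base schemas), and only then run the consumer, which now validates the registered schemas against the plan. The sketch below is illustrative, not part of the patch; the test name and the printed output are assumptions, while the helper calls mirror function_test.rs and substrait_validations.rs.

use crate::utils::test::{add_plan_schemas_to_ctx, read_json};
use datafusion::common::Result;
use datafusion::prelude::SessionContext;
use datafusion_substrait::logical_plan::consumer::from_substrait_plan;

#[tokio::test]
async fn consume_plan_with_schemas_from_plan() -> Result<()> {
    // Parse the Substrait plan JSON shipped with the tests.
    let proto_plan = read_json("tests/testdata/test_plans/simple_select.substrait.json");
    // Register an EmptyTable for each NamedTable in the plan, built from the plan's base schemas.
    let ctx = add_plan_schemas_to_ctx(SessionContext::new(), &proto_plan);
    // The consumer checks the registered table schemas against the plan's base schemas.
    let plan = from_substrait_plan(&ctx, &proto_plan).await?;
    println!("{plan}");
    Ok(())
}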