From 41c5f4ed67fd48c27fb923c3c5cf8fd33f126ca4 Mon Sep 17 00:00:00 2001 From: Victor Barua Date: Tue, 10 Sep 2024 03:58:06 -0700 Subject: [PATCH] validate and adjust Substrait NamedTable schemas (#12223) (#12245) * fix: producer did not emit base_schema struct field for ReadRel Substrait plans are not valid without this, and it is generally useful for round trip testing * feat: include field_qualifier param for from_substrait_named_struct * feat: verify that Substrait and DataFusion agree on NamedScan schemas If the schema registered with DataFusion and the schema as given by the Substrait NamedScan do not have the same names and types, DataFusion should reject it * test: update existing substrait test + substrait validation test * added substrait_validation test * extracted useful test utilities The utils::test::TestSchemaCollector::generate_context_from_plan function can be used to dynamically generate a SessionContext from a Substrait plan, which will include the schemas for NamedTables as given in the Substrait plan. This helps us avoid the issue of DataFusion test schemas and Substrait plan schemas not being in sync. * feat: expose from_substrait_named_struct * refactor: remove unused imports * docs: add missing licenses * refactor: deal with unused code warnings * remove optional qualifier from from_substrait_named_struct * return DFSchema from from_substrait_named_struct * one must imagine clippy happy * accidental blah * loosen the validation for schemas allow cases where the Substrait schema is a subset of the DataFusion schema * minor doc tweaks * update test data to deal with case issues in tests * fix error message * improve readability of field compatability check * make TestSchemaCollector more flexible * fix doc typo Co-authored-by: Arttu * remove unecessary TODO * handle ReadRel projection on top of mismatched schema --------- Co-authored-by: Arttu --- .../substrait/src/logical_plan/consumer.rs | 131 +++++++++++- .../substrait/src/logical_plan/producer.rs | 16 +- .../tests/cases/consumer_integration.rs | 56 +++--- .../substrait/tests/cases/function_test.rs | 32 +-- .../substrait/tests/cases/logical_plans.rs | 59 ++---- datafusion/substrait/tests/cases/mod.rs | 1 + .../tests/cases/substrait_validations.rs | 151 ++++++++++++++ .../substrait/tests/substrait_integration.rs | 1 + .../test_plans/simple_select.substrait.json | 69 +++++++ .../simple_select_with_mask.substrait.json | 104 ++++++++++ .../substrait/tests/testdata/tpch/nation.csv | 2 +- .../substrait/tests/testdata/tpch/region.csv | 2 +- datafusion/substrait/tests/utils.rs | 186 ++++++++++++++++++ 13 files changed, 697 insertions(+), 113 deletions(-) create mode 100644 datafusion/substrait/tests/cases/substrait_validations.rs create mode 100644 datafusion/substrait/tests/testdata/test_plans/simple_select.substrait.json create mode 100644 datafusion/substrait/tests/testdata/test_plans/simple_select_with_mask.substrait.json create mode 100644 datafusion/substrait/tests/utils.rs diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index 01a854ffbdf2..c9d27237a49b 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -53,6 +53,8 @@ use crate::variation_const::{ }; use datafusion::arrow::array::{new_empty_array, AsArray}; use datafusion::common::scalar::ScalarStructBuilder; +use datafusion::dataframe::DataFrame; +use datafusion::logical_expr::builder::project; use 
datafusion::logical_expr::expr::InList; use datafusion::logical_expr::{ col, expr, Cast, Extension, GroupingSet, Like, LogicalPlanBuilder, Partitioning, @@ -277,6 +279,20 @@ pub fn extract_projection( ); Ok(LogicalPlan::TableScan(scan)) } + LogicalPlan::Projection(projection) => { + // create another Projection around the Projection to handle the field masking + let fields: Vec<Expr> = column_indices + .into_iter() + .map(|i| { + let (qualifier, field) = + projection.schema.qualified_field(i); + let column = + Column::new(qualifier.cloned(), field.name()); + Expr::Column(column) + }) + .collect(); + project(LogicalPlan::Projection(projection), fields) + } _ => plan_err!("unexpected plan for table"), } } @@ -640,6 +656,10 @@ pub async fn from_substrait_rel( } Some(RelType::Read(read)) => match &read.as_ref().read_type { Some(ReadType::NamedTable(nt)) => { + let named_struct = read.base_schema.as_ref().ok_or_else(|| { + substrait_datafusion_err!("No base schema provided for Named Table") + })?; + let table_reference = match nt.names.len() { 0 => { return plan_err!("No table name found in NamedTable"); @@ -657,7 +677,13 @@ pub async fn from_substrait_rel( table: nt.names[2].clone().into(), }, }; - let t = ctx.table(table_reference).await?; + + let substrait_schema = + from_substrait_named_struct(named_struct, extensions)? + .replace_qualifier(table_reference.clone()); + + let t = ctx.table(table_reference.clone()).await?; + let t = ensure_schema_compatability(t, substrait_schema)?; let t = t.into_optimized_plan()?; extract_projection(t, &read.projection) } @@ -671,7 +697,7 @@ pub async fn from_substrait_rel( if vt.values.is_empty() { return Ok(LogicalPlan::EmptyRelation(EmptyRelation { produce_one_row: false, - schema, + schema: DFSchemaRef::new(schema), })); } @@ -704,7 +730,10 @@ pub async fn from_substrait_rel( }) .collect::<Result<_>>()?; - Ok(LogicalPlan::Values(Values { schema, values })) + Ok(LogicalPlan::Values(Values { + schema: DFSchemaRef::new(schema), + values, + })) } Some(ReadType::LocalFiles(lf)) => { fn extract_filename(name: &str) -> Option<String> { @@ -850,6 +879,87 @@ pub async fn from_substrait_rel( } } +/// Ensures that the given Substrait schema is compatible with the schema as given by DataFusion +/// +/// This means: +/// 1. All fields present in the Substrait schema are present in the DataFusion schema. The +/// DataFusion schema may have MORE fields, but not the other way around. +/// 2. All fields are compatible. See [`ensure_field_compatability`] for details +/// +/// This function returns a DataFrame with fields adjusted if necessary in the event that the +/// Substrait schema is a subset of the DataFusion schema. +fn ensure_schema_compatability( + table: DataFrame, + substrait_schema: DFSchema, +) -> Result<DataFrame> { + let df_schema = table.schema().to_owned().strip_qualifiers(); + if df_schema.logically_equivalent_names_and_types(&substrait_schema) { + return Ok(table); + } + let selected_columns = substrait_schema + .strip_qualifiers() + .fields() + .iter() + .map(|substrait_field| { + let df_field = + df_schema.field_with_unqualified_name(substrait_field.name())?; + ensure_field_compatability(df_field, substrait_field)?; + Ok(col(format!("\"{}\"", df_field.name()))) + }) + .collect::<Result<_>>()?; + + table.select(selected_columns) +} + +/// Ensures that the given Substrait field is compatible with the given DataFusion field +/// +/// A field is compatible between Substrait and DataFusion if: +/// 1. They have logically equivalent types. +/// 2. 
They have the same nullability OR the Substrait field is nullable and the DataFusion fields +/// is not nullable. +/// +/// If a Substrait field is not nullable, the Substrait plan may be built around assuming it is not +/// nullable. As such if DataFusion has that field as nullable the plan should be rejected. +fn ensure_field_compatability( + datafusion_field: &Field, + substrait_field: &Field, +) -> Result<()> { + if !DFSchema::datatype_is_logically_equal( + datafusion_field.data_type(), + substrait_field.data_type(), + ) { + return substrait_err!( + "Field '{}' in Substrait schema has a different type ({}) than the corresponding field in the table schema ({}).", + substrait_field.name(), + substrait_field.data_type(), + datafusion_field.data_type() + ); + } + + if !compatible_nullabilities( + datafusion_field.is_nullable(), + substrait_field.is_nullable(), + ) { + // TODO: from_substrait_struct_type needs to be updated to set the nullability correctly. It defaults to true for now. + return substrait_err!( + "Field '{}' is nullable in the DataFusion schema but not nullable in the Substrait schema.", + substrait_field.name() + ); + } + Ok(()) +} + +/// Returns true if the DataFusion and Substrait nullabilities are compatible, false otherwise +fn compatible_nullabilities( + datafusion_nullability: bool, + substrait_nullability: bool, +) -> bool { + // DataFusion and Substrait have the same nullability + (datafusion_nullability == substrait_nullability) + // DataFusion is not nullable and Substrait is nullable + || (!datafusion_nullability && substrait_nullability) +} + /// (Re)qualify the sides of a join if needed, i.e. if the columns from one side would otherwise /// conflict with the columns from the other. /// Substrait doesn't currently allow specifying aliases, neither for columns nor for tables. 
For @@ -1588,10 +1698,11 @@ fn next_struct_field_name( } } -fn from_substrait_named_struct( +/// Convert Substrait NamedStruct to DataFusion DFSchemaRef +pub fn from_substrait_named_struct( base_schema: &NamedStruct, extensions: &Extensions, -) -> Result { +) -> Result { let mut name_idx = 0; let fields = from_substrait_struct_type( base_schema.r#struct.as_ref().ok_or_else(|| { @@ -1603,12 +1714,12 @@ fn from_substrait_named_struct( ); if name_idx != base_schema.names.len() { return substrait_err!( - "Names list must match exactly to nested schema, but found {} uses for {} names", - name_idx, - base_schema.names.len() - ); + "Names list must match exactly to nested schema, but found {} uses for {} names", + name_idx, + base_schema.names.len() + ); } - Ok(DFSchemaRef::new(DFSchema::try_from(Schema::new(fields?))?)) + DFSchema::try_from(Schema::new(fields?)) } fn from_substrait_bound( diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index f323ae146600..a923aaf31abb 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -42,8 +42,8 @@ use crate::variation_const::{ use datafusion::arrow::array::{Array, GenericListArray, OffsetSizeTrait}; use datafusion::common::{ exec_err, internal_err, not_impl_err, plan_err, substrait_datafusion_err, + substrait_err, DFSchemaRef, ToDFSchema, }; -use datafusion::common::{substrait_err, DFSchemaRef}; #[allow(unused_imports)] use datafusion::logical_expr::expr::{ Alias, BinaryExpr, Case, Cast, GroupingSet, InList, InSubquery, Sort, WindowFunction, @@ -139,19 +139,13 @@ pub fn to_substrait_rel( maintain_singular_struct: false, }); + let table_schema = scan.source.schema().to_dfschema_ref()?; + let base_schema = to_substrait_named_struct(&table_schema, extensions)?; + Ok(Box::new(Rel { rel_type: Some(RelType::Read(Box::new(ReadRel { common: None, - base_schema: Some(NamedStruct { - names: scan - .source - .schema() - .fields() - .iter() - .map(|f| f.name().to_owned()) - .collect(), - r#struct: None, - }), + base_schema: Some(base_schema), filter: None, best_effort_filter: None, projection, diff --git a/datafusion/substrait/tests/cases/consumer_integration.rs b/datafusion/substrait/tests/cases/consumer_integration.rs index 0a86d27e013c..dad24559a06f 100644 --- a/datafusion/substrait/tests/cases/consumer_integration.rs +++ b/datafusion/substrait/tests/cases/consumer_integration.rs @@ -92,22 +92,22 @@ mod tests { let plan_str = format!("{}", plan); assert_eq!( plan_str, - "Projection: FILENAME_PLACEHOLDER_1.s_acctbal AS S_ACCTBAL, FILENAME_PLACEHOLDER_1.s_name AS S_NAME, FILENAME_PLACEHOLDER_3.n_name AS N_NAME, FILENAME_PLACEHOLDER_0.p_partkey AS P_PARTKEY, FILENAME_PLACEHOLDER_0.p_mfgr AS P_MFGR, FILENAME_PLACEHOLDER_1.s_address AS S_ADDRESS, FILENAME_PLACEHOLDER_1.s_phone AS S_PHONE, FILENAME_PLACEHOLDER_1.s_comment AS S_COMMENT\ + "Projection: FILENAME_PLACEHOLDER_1.s_acctbal AS S_ACCTBAL, FILENAME_PLACEHOLDER_1.s_name AS S_NAME, FILENAME_PLACEHOLDER_3.N_NAME, FILENAME_PLACEHOLDER_0.p_partkey AS P_PARTKEY, FILENAME_PLACEHOLDER_0.p_mfgr AS P_MFGR, FILENAME_PLACEHOLDER_1.s_address AS S_ADDRESS, FILENAME_PLACEHOLDER_1.s_phone AS S_PHONE, FILENAME_PLACEHOLDER_1.s_comment AS S_COMMENT\ \n Limit: skip=0, fetch=100\ - \n Sort: FILENAME_PLACEHOLDER_1.s_acctbal DESC NULLS FIRST, FILENAME_PLACEHOLDER_3.n_name ASC NULLS LAST, FILENAME_PLACEHOLDER_1.s_name ASC NULLS LAST, FILENAME_PLACEHOLDER_0.p_partkey ASC NULLS LAST\ - \n Projection: 
FILENAME_PLACEHOLDER_1.s_acctbal, FILENAME_PLACEHOLDER_1.s_name, FILENAME_PLACEHOLDER_3.n_name, FILENAME_PLACEHOLDER_0.p_partkey, FILENAME_PLACEHOLDER_0.p_mfgr, FILENAME_PLACEHOLDER_1.s_address, FILENAME_PLACEHOLDER_1.s_phone, FILENAME_PLACEHOLDER_1.s_comment\ - \n Filter: FILENAME_PLACEHOLDER_0.p_partkey = FILENAME_PLACEHOLDER_2.ps_partkey AND FILENAME_PLACEHOLDER_1.s_suppkey = FILENAME_PLACEHOLDER_2.ps_suppkey AND FILENAME_PLACEHOLDER_0.p_size = Int32(15) AND FILENAME_PLACEHOLDER_0.p_type LIKE CAST(Utf8(\"%BRASS\") AS Utf8) AND FILENAME_PLACEHOLDER_1.s_nationkey = FILENAME_PLACEHOLDER_3.n_nationkey AND FILENAME_PLACEHOLDER_3.n_regionkey = FILENAME_PLACEHOLDER_4.r_regionkey AND FILENAME_PLACEHOLDER_4.r_name = CAST(Utf8(\"EUROPE\") AS Utf8) AND FILENAME_PLACEHOLDER_2.ps_supplycost = ()\ + \n Sort: FILENAME_PLACEHOLDER_1.s_acctbal DESC NULLS FIRST, FILENAME_PLACEHOLDER_3.N_NAME ASC NULLS LAST, FILENAME_PLACEHOLDER_1.s_name ASC NULLS LAST, FILENAME_PLACEHOLDER_0.p_partkey ASC NULLS LAST\ + \n Projection: FILENAME_PLACEHOLDER_1.s_acctbal, FILENAME_PLACEHOLDER_1.s_name, FILENAME_PLACEHOLDER_3.N_NAME, FILENAME_PLACEHOLDER_0.p_partkey, FILENAME_PLACEHOLDER_0.p_mfgr, FILENAME_PLACEHOLDER_1.s_address, FILENAME_PLACEHOLDER_1.s_phone, FILENAME_PLACEHOLDER_1.s_comment\ + \n Filter: FILENAME_PLACEHOLDER_0.p_partkey = FILENAME_PLACEHOLDER_2.ps_partkey AND FILENAME_PLACEHOLDER_1.s_suppkey = FILENAME_PLACEHOLDER_2.ps_suppkey AND FILENAME_PLACEHOLDER_0.p_size = Int32(15) AND FILENAME_PLACEHOLDER_0.p_type LIKE CAST(Utf8(\"%BRASS\") AS Utf8) AND FILENAME_PLACEHOLDER_1.s_nationkey = FILENAME_PLACEHOLDER_3.N_NATIONKEY AND FILENAME_PLACEHOLDER_3.N_REGIONKEY = FILENAME_PLACEHOLDER_4.R_REGIONKEY AND FILENAME_PLACEHOLDER_4.R_NAME = CAST(Utf8(\"EUROPE\") AS Utf8) AND FILENAME_PLACEHOLDER_2.ps_supplycost = ()\ \n Subquery:\ \n Aggregate: groupBy=[[]], aggr=[[min(FILENAME_PLACEHOLDER_5.ps_supplycost)]]\ \n Projection: FILENAME_PLACEHOLDER_5.ps_supplycost\ - \n Filter: FILENAME_PLACEHOLDER_5.ps_partkey = FILENAME_PLACEHOLDER_5.ps_partkey AND FILENAME_PLACEHOLDER_6.s_suppkey = FILENAME_PLACEHOLDER_5.ps_suppkey AND FILENAME_PLACEHOLDER_6.s_nationkey = FILENAME_PLACEHOLDER_7.n_nationkey AND FILENAME_PLACEHOLDER_7.n_regionkey = FILENAME_PLACEHOLDER_8.r_regionkey AND FILENAME_PLACEHOLDER_8.r_name = CAST(Utf8(\"EUROPE\") AS Utf8)\ + \n Filter: FILENAME_PLACEHOLDER_5.ps_partkey = FILENAME_PLACEHOLDER_5.ps_partkey AND FILENAME_PLACEHOLDER_6.s_suppkey = FILENAME_PLACEHOLDER_5.ps_suppkey AND FILENAME_PLACEHOLDER_6.s_nationkey = FILENAME_PLACEHOLDER_7.N_NATIONKEY AND FILENAME_PLACEHOLDER_7.N_REGIONKEY = FILENAME_PLACEHOLDER_8.R_REGIONKEY AND FILENAME_PLACEHOLDER_8.R_NAME = CAST(Utf8(\"EUROPE\") AS Utf8)\ \n Inner Join: Filter: Boolean(true)\ \n Inner Join: Filter: Boolean(true)\ \n Inner Join: Filter: Boolean(true)\ \n TableScan: FILENAME_PLACEHOLDER_5 projection=[ps_partkey, ps_suppkey, ps_availqty, ps_supplycost, ps_comment]\ \n TableScan: FILENAME_PLACEHOLDER_6 projection=[s_suppkey, s_name, s_address, s_nationkey, s_phone, s_acctbal, s_comment]\ - \n TableScan: FILENAME_PLACEHOLDER_7 projection=[n_nationkey, n_name, n_regionkey, n_comment]\ - \n TableScan: FILENAME_PLACEHOLDER_8 projection=[r_regionkey, r_name, r_comment]\ + \n TableScan: FILENAME_PLACEHOLDER_7 projection=[N_NATIONKEY, N_NAME, N_REGIONKEY, N_COMMENT]\ + \n TableScan: FILENAME_PLACEHOLDER_8 projection=[R_REGIONKEY, R_NAME, R_COMMENT]\ \n Inner Join: Filter: Boolean(true)\ \n Inner Join: Filter: Boolean(true)\ \n Inner Join: Filter: Boolean(true)\ @@ -115,8 
+115,8 @@ mod tests { \n TableScan: FILENAME_PLACEHOLDER_0 projection=[p_partkey, p_name, p_mfgr, p_brand, p_type, p_size, p_container, p_retailprice, p_comment]\ \n TableScan: FILENAME_PLACEHOLDER_1 projection=[s_suppkey, s_name, s_address, s_nationkey, s_phone, s_acctbal, s_comment]\ \n TableScan: FILENAME_PLACEHOLDER_2 projection=[ps_partkey, ps_suppkey, ps_availqty, ps_supplycost, ps_comment]\ - \n TableScan: FILENAME_PLACEHOLDER_3 projection=[n_nationkey, n_name, n_regionkey, n_comment]\ - \n TableScan: FILENAME_PLACEHOLDER_4 projection=[r_regionkey, r_name, r_comment]" + \n TableScan: FILENAME_PLACEHOLDER_3 projection=[N_NATIONKEY, N_NAME, N_REGIONKEY, N_COMMENT]\ + \n TableScan: FILENAME_PLACEHOLDER_4 projection=[R_REGIONKEY, R_NAME, R_COMMENT]" ); Ok(()) } @@ -196,11 +196,11 @@ mod tests { let plan = from_substrait_plan(&ctx, &proto).await?; let plan_str = format!("{}", plan); - assert_eq!(plan_str, "Projection: NATION.n_name AS N_NAME, sum(FILENAME_PLACEHOLDER_2.l_extendedprice * Int32(1) - FILENAME_PLACEHOLDER_2.l_discount) AS REVENUE\ + assert_eq!(plan_str, "Projection: NATION.N_NAME, sum(FILENAME_PLACEHOLDER_2.l_extendedprice * Int32(1) - FILENAME_PLACEHOLDER_2.l_discount) AS REVENUE\ \n Sort: sum(FILENAME_PLACEHOLDER_2.l_extendedprice * Int32(1) - FILENAME_PLACEHOLDER_2.l_discount) DESC NULLS FIRST\ - \n Aggregate: groupBy=[[NATION.n_name]], aggr=[[sum(FILENAME_PLACEHOLDER_2.l_extendedprice * Int32(1) - FILENAME_PLACEHOLDER_2.l_discount)]]\ - \n Projection: NATION.n_name, FILENAME_PLACEHOLDER_2.l_extendedprice * (CAST(Int32(1) AS Decimal128(19, 0)) - FILENAME_PLACEHOLDER_2.l_discount)\ - \n Filter: FILENAME_PLACEHOLDER_0.c_custkey = FILENAME_PLACEHOLDER_1.o_custkey AND FILENAME_PLACEHOLDER_2.l_orderkey = FILENAME_PLACEHOLDER_1.o_orderkey AND FILENAME_PLACEHOLDER_2.l_suppkey = FILENAME_PLACEHOLDER_3.s_suppkey AND FILENAME_PLACEHOLDER_0.c_nationkey = FILENAME_PLACEHOLDER_3.s_nationkey AND FILENAME_PLACEHOLDER_3.s_nationkey = NATION.n_nationkey AND NATION.n_regionkey = REGION.r_regionkey AND REGION.r_name = CAST(Utf8(\"ASIA\") AS Utf8) AND FILENAME_PLACEHOLDER_1.o_orderdate >= CAST(Utf8(\"1994-01-01\") AS Date32) AND FILENAME_PLACEHOLDER_1.o_orderdate < CAST(Utf8(\"1995-01-01\") AS Date32)\ + \n Aggregate: groupBy=[[NATION.N_NAME]], aggr=[[sum(FILENAME_PLACEHOLDER_2.l_extendedprice * Int32(1) - FILENAME_PLACEHOLDER_2.l_discount)]]\ + \n Projection: NATION.N_NAME, FILENAME_PLACEHOLDER_2.l_extendedprice * (CAST(Int32(1) AS Decimal128(19, 0)) - FILENAME_PLACEHOLDER_2.l_discount)\ + \n Filter: FILENAME_PLACEHOLDER_0.c_custkey = FILENAME_PLACEHOLDER_1.o_custkey AND FILENAME_PLACEHOLDER_2.l_orderkey = FILENAME_PLACEHOLDER_1.o_orderkey AND FILENAME_PLACEHOLDER_2.l_suppkey = FILENAME_PLACEHOLDER_3.s_suppkey AND FILENAME_PLACEHOLDER_0.c_nationkey = FILENAME_PLACEHOLDER_3.s_nationkey AND FILENAME_PLACEHOLDER_3.s_nationkey = NATION.N_NATIONKEY AND NATION.N_REGIONKEY = REGION.R_REGIONKEY AND REGION.R_NAME = CAST(Utf8(\"ASIA\") AS Utf8) AND FILENAME_PLACEHOLDER_1.o_orderdate >= CAST(Utf8(\"1994-01-01\") AS Date32) AND FILENAME_PLACEHOLDER_1.o_orderdate < CAST(Utf8(\"1995-01-01\") AS Date32)\ \n Inner Join: Filter: Boolean(true)\ \n Inner Join: Filter: Boolean(true)\ \n Inner Join: Filter: Boolean(true)\ @@ -210,8 +210,8 @@ mod tests { \n TableScan: FILENAME_PLACEHOLDER_1 projection=[o_orderkey, o_custkey, o_orderstatus, o_totalprice, o_orderdate, o_orderpriority, o_clerk, o_shippriority, o_comment]\ \n TableScan: FILENAME_PLACEHOLDER_2 projection=[l_orderkey, l_partkey, l_suppkey, 
l_linenumber, l_quantity, l_extendedprice, l_discount, l_tax, l_returnflag, l_linestatus, l_shipdate, l_commitdate, l_receiptdate, l_shipinstruct, l_shipmode, l_comment]\ \n TableScan: FILENAME_PLACEHOLDER_3 projection=[s_suppkey, s_name, s_address, s_nationkey, s_phone, s_acctbal, s_comment]\ - \n TableScan: NATION projection=[n_nationkey, n_name, n_regionkey, n_comment]\ - \n TableScan: REGION projection=[r_regionkey, r_name, r_comment]"); + \n TableScan: NATION projection=[N_NATIONKEY, N_NAME, N_REGIONKEY, N_COMMENT]\ + \n TableScan: REGION projection=[R_REGIONKEY, R_NAME, R_COMMENT]"); Ok(()) } @@ -255,19 +255,19 @@ mod tests { let plan = from_substrait_plan(&ctx, &proto).await?; let plan_str = format!("{}", plan); - assert_eq!(plan_str, "Projection: FILENAME_PLACEHOLDER_0.c_custkey AS C_CUSTKEY, FILENAME_PLACEHOLDER_0.c_name AS C_NAME, sum(FILENAME_PLACEHOLDER_2.l_extendedprice * Int32(1) - FILENAME_PLACEHOLDER_2.l_discount) AS REVENUE, FILENAME_PLACEHOLDER_0.c_acctbal AS C_ACCTBAL, FILENAME_PLACEHOLDER_3.n_name AS N_NAME, FILENAME_PLACEHOLDER_0.c_address AS C_ADDRESS, FILENAME_PLACEHOLDER_0.c_phone AS C_PHONE, FILENAME_PLACEHOLDER_0.c_comment AS C_COMMENT\ + assert_eq!(plan_str, "Projection: FILENAME_PLACEHOLDER_0.c_custkey AS C_CUSTKEY, FILENAME_PLACEHOLDER_0.c_name AS C_NAME, sum(FILENAME_PLACEHOLDER_2.l_extendedprice * Int32(1) - FILENAME_PLACEHOLDER_2.l_discount) AS REVENUE, FILENAME_PLACEHOLDER_0.c_acctbal AS C_ACCTBAL, FILENAME_PLACEHOLDER_3.N_NAME, FILENAME_PLACEHOLDER_0.c_address AS C_ADDRESS, FILENAME_PLACEHOLDER_0.c_phone AS C_PHONE, FILENAME_PLACEHOLDER_0.c_comment AS C_COMMENT\ \n Limit: skip=0, fetch=20\ \n Sort: sum(FILENAME_PLACEHOLDER_2.l_extendedprice * Int32(1) - FILENAME_PLACEHOLDER_2.l_discount) DESC NULLS FIRST\ - \n Projection: FILENAME_PLACEHOLDER_0.c_custkey, FILENAME_PLACEHOLDER_0.c_name, sum(FILENAME_PLACEHOLDER_2.l_extendedprice * Int32(1) - FILENAME_PLACEHOLDER_2.l_discount), FILENAME_PLACEHOLDER_0.c_acctbal, FILENAME_PLACEHOLDER_3.n_name, FILENAME_PLACEHOLDER_0.c_address, FILENAME_PLACEHOLDER_0.c_phone, FILENAME_PLACEHOLDER_0.c_comment\n Aggregate: groupBy=[[FILENAME_PLACEHOLDER_0.c_custkey, FILENAME_PLACEHOLDER_0.c_name, FILENAME_PLACEHOLDER_0.c_acctbal, FILENAME_PLACEHOLDER_0.c_phone, FILENAME_PLACEHOLDER_3.n_name, FILENAME_PLACEHOLDER_0.c_address, FILENAME_PLACEHOLDER_0.c_comment]], aggr=[[sum(FILENAME_PLACEHOLDER_2.l_extendedprice * Int32(1) - FILENAME_PLACEHOLDER_2.l_discount)]]\ - \n Projection: FILENAME_PLACEHOLDER_0.c_custkey, FILENAME_PLACEHOLDER_0.c_name, FILENAME_PLACEHOLDER_0.c_acctbal, FILENAME_PLACEHOLDER_0.c_phone, FILENAME_PLACEHOLDER_3.n_name, FILENAME_PLACEHOLDER_0.c_address, FILENAME_PLACEHOLDER_0.c_comment, FILENAME_PLACEHOLDER_2.l_extendedprice * (CAST(Int32(1) AS Decimal128(19, 0)) - FILENAME_PLACEHOLDER_2.l_discount)\ - \n Filter: FILENAME_PLACEHOLDER_0.c_custkey = FILENAME_PLACEHOLDER_1.o_custkey AND FILENAME_PLACEHOLDER_2.l_orderkey = FILENAME_PLACEHOLDER_1.o_orderkey AND FILENAME_PLACEHOLDER_1.o_orderdate >= CAST(Utf8(\"1993-10-01\") AS Date32) AND FILENAME_PLACEHOLDER_1.o_orderdate < CAST(Utf8(\"1994-01-01\") AS Date32) AND FILENAME_PLACEHOLDER_2.l_returnflag = Utf8(\"R\") AND FILENAME_PLACEHOLDER_0.c_nationkey = FILENAME_PLACEHOLDER_3.n_nationkey\ + \n Projection: FILENAME_PLACEHOLDER_0.c_custkey, FILENAME_PLACEHOLDER_0.c_name, sum(FILENAME_PLACEHOLDER_2.l_extendedprice * Int32(1) - FILENAME_PLACEHOLDER_2.l_discount), FILENAME_PLACEHOLDER_0.c_acctbal, FILENAME_PLACEHOLDER_3.N_NAME, FILENAME_PLACEHOLDER_0.c_address, 
FILENAME_PLACEHOLDER_0.c_phone, FILENAME_PLACEHOLDER_0.c_comment\n Aggregate: groupBy=[[FILENAME_PLACEHOLDER_0.c_custkey, FILENAME_PLACEHOLDER_0.c_name, FILENAME_PLACEHOLDER_0.c_acctbal, FILENAME_PLACEHOLDER_0.c_phone, FILENAME_PLACEHOLDER_3.N_NAME, FILENAME_PLACEHOLDER_0.c_address, FILENAME_PLACEHOLDER_0.c_comment]], aggr=[[sum(FILENAME_PLACEHOLDER_2.l_extendedprice * Int32(1) - FILENAME_PLACEHOLDER_2.l_discount)]]\ + \n Projection: FILENAME_PLACEHOLDER_0.c_custkey, FILENAME_PLACEHOLDER_0.c_name, FILENAME_PLACEHOLDER_0.c_acctbal, FILENAME_PLACEHOLDER_0.c_phone, FILENAME_PLACEHOLDER_3.N_NAME, FILENAME_PLACEHOLDER_0.c_address, FILENAME_PLACEHOLDER_0.c_comment, FILENAME_PLACEHOLDER_2.l_extendedprice * (CAST(Int32(1) AS Decimal128(19, 0)) - FILENAME_PLACEHOLDER_2.l_discount)\ + \n Filter: FILENAME_PLACEHOLDER_0.c_custkey = FILENAME_PLACEHOLDER_1.o_custkey AND FILENAME_PLACEHOLDER_2.l_orderkey = FILENAME_PLACEHOLDER_1.o_orderkey AND FILENAME_PLACEHOLDER_1.o_orderdate >= CAST(Utf8(\"1993-10-01\") AS Date32) AND FILENAME_PLACEHOLDER_1.o_orderdate < CAST(Utf8(\"1994-01-01\") AS Date32) AND FILENAME_PLACEHOLDER_2.l_returnflag = Utf8(\"R\") AND FILENAME_PLACEHOLDER_0.c_nationkey = FILENAME_PLACEHOLDER_3.N_NATIONKEY\ \n Inner Join: Filter: Boolean(true)\ \n Inner Join: Filter: Boolean(true)\ \n Inner Join: Filter: Boolean(true)\ \n TableScan: FILENAME_PLACEHOLDER_0 projection=[c_custkey, c_name, c_address, c_nationkey, c_phone, c_acctbal, c_mktsegment, c_comment]\ \n TableScan: FILENAME_PLACEHOLDER_1 projection=[o_orderkey, o_custkey, o_orderstatus, o_totalprice, o_orderdate, o_orderpriority, o_clerk, o_shippriority, o_comment]\ \n TableScan: FILENAME_PLACEHOLDER_2 projection=[l_orderkey, l_partkey, l_suppkey, l_linenumber, l_quantity, l_extendedprice, l_discount, l_tax, l_returnflag, l_linestatus, l_shipdate, l_commitdate, l_receiptdate, l_shipinstruct, l_shipmode, l_comment]\ - \n TableScan: FILENAME_PLACEHOLDER_3 projection=[n_nationkey, n_name, n_regionkey, n_comment]"); + \n TableScan: FILENAME_PLACEHOLDER_3 projection=[N_NATIONKEY, N_NAME, N_REGIONKEY, N_COMMENT]"); Ok(()) } @@ -297,20 +297,20 @@ mod tests { \n Projection: sum(FILENAME_PLACEHOLDER_3.ps_supplycost * FILENAME_PLACEHOLDER_3.ps_availqty) * Decimal128(Some(1000000),11,10)\ \n Aggregate: groupBy=[[]], aggr=[[sum(FILENAME_PLACEHOLDER_3.ps_supplycost * FILENAME_PLACEHOLDER_3.ps_availqty)]]\ \n Projection: FILENAME_PLACEHOLDER_3.ps_supplycost * CAST(FILENAME_PLACEHOLDER_3.ps_availqty AS Decimal128(19, 0))\ - \n Filter: FILENAME_PLACEHOLDER_3.ps_suppkey = FILENAME_PLACEHOLDER_4.s_suppkey AND FILENAME_PLACEHOLDER_4.s_nationkey = FILENAME_PLACEHOLDER_5.n_nationkey AND FILENAME_PLACEHOLDER_5.n_name = CAST(Utf8(\"JAPAN\") AS Utf8)\ + \n Filter: FILENAME_PLACEHOLDER_3.ps_suppkey = FILENAME_PLACEHOLDER_4.s_suppkey AND FILENAME_PLACEHOLDER_4.s_nationkey = FILENAME_PLACEHOLDER_5.N_NATIONKEY AND FILENAME_PLACEHOLDER_5.N_NAME = CAST(Utf8(\"JAPAN\") AS Utf8)\ \n Inner Join: Filter: Boolean(true)\ \n Inner Join: Filter: Boolean(true)\ \n TableScan: FILENAME_PLACEHOLDER_3 projection=[ps_partkey, ps_suppkey, ps_availqty, ps_supplycost, ps_comment]\ \n TableScan: FILENAME_PLACEHOLDER_4 projection=[s_suppkey, s_name, s_address, s_nationkey, s_phone, s_acctbal, s_comment]\ - \n TableScan: FILENAME_PLACEHOLDER_5 projection=[n_nationkey, n_name, n_regionkey, n_comment]\ + \n TableScan: FILENAME_PLACEHOLDER_5 projection=[N_NATIONKEY, N_NAME, N_REGIONKEY, N_COMMENT]\ \n Aggregate: groupBy=[[FILENAME_PLACEHOLDER_0.ps_partkey]], 
aggr=[[sum(FILENAME_PLACEHOLDER_0.ps_supplycost * FILENAME_PLACEHOLDER_0.ps_availqty)]]\ \n Projection: FILENAME_PLACEHOLDER_0.ps_partkey, FILENAME_PLACEHOLDER_0.ps_supplycost * CAST(FILENAME_PLACEHOLDER_0.ps_availqty AS Decimal128(19, 0))\ - \n Filter: FILENAME_PLACEHOLDER_0.ps_suppkey = FILENAME_PLACEHOLDER_1.s_suppkey AND FILENAME_PLACEHOLDER_1.s_nationkey = FILENAME_PLACEHOLDER_2.n_nationkey AND FILENAME_PLACEHOLDER_2.n_name = CAST(Utf8(\"JAPAN\") AS Utf8)\ + \n Filter: FILENAME_PLACEHOLDER_0.ps_suppkey = FILENAME_PLACEHOLDER_1.s_suppkey AND FILENAME_PLACEHOLDER_1.s_nationkey = FILENAME_PLACEHOLDER_2.N_NATIONKEY AND FILENAME_PLACEHOLDER_2.N_NAME = CAST(Utf8(\"JAPAN\") AS Utf8)\ \n Inner Join: Filter: Boolean(true)\ \n Inner Join: Filter: Boolean(true)\ \n TableScan: FILENAME_PLACEHOLDER_0 projection=[ps_partkey, ps_suppkey, ps_availqty, ps_supplycost, ps_comment]\ \n TableScan: FILENAME_PLACEHOLDER_1 projection=[s_suppkey, s_name, s_address, s_nationkey, s_phone, s_acctbal, s_comment]\ - \n TableScan: FILENAME_PLACEHOLDER_2 projection=[n_nationkey, n_name, n_regionkey, n_comment]"); + \n TableScan: FILENAME_PLACEHOLDER_2 projection=[N_NATIONKEY, N_NAME, N_REGIONKEY, N_COMMENT]"); Ok(()) } @@ -498,7 +498,7 @@ mod tests { assert_eq!(plan_str, "Projection: FILENAME_PLACEHOLDER_0.s_name AS S_NAME, FILENAME_PLACEHOLDER_0.s_address AS S_ADDRESS\ \n Sort: FILENAME_PLACEHOLDER_0.s_name ASC NULLS LAST\ \n Projection: FILENAME_PLACEHOLDER_0.s_name, FILENAME_PLACEHOLDER_0.s_address\ - \n Filter: CAST(FILENAME_PLACEHOLDER_0.s_suppkey IN () AS Boolean) AND FILENAME_PLACEHOLDER_0.s_nationkey = FILENAME_PLACEHOLDER_1.n_nationkey AND FILENAME_PLACEHOLDER_1.n_name = CAST(Utf8(\"CANADA\") AS Utf8)\ + \n Filter: CAST(FILENAME_PLACEHOLDER_0.s_suppkey IN () AS Boolean) AND FILENAME_PLACEHOLDER_0.s_nationkey = FILENAME_PLACEHOLDER_1.N_NATIONKEY AND FILENAME_PLACEHOLDER_1.N_NAME = CAST(Utf8(\"CANADA\") AS Utf8)\ \n Subquery:\ \n Projection: FILENAME_PLACEHOLDER_2.ps_suppkey\ \n Filter: CAST(FILENAME_PLACEHOLDER_2.ps_partkey IN () AS Boolean) AND CAST(FILENAME_PLACEHOLDER_2.ps_availqty AS Decimal128(19, 1)) > ()\ @@ -515,7 +515,7 @@ mod tests { \n TableScan: FILENAME_PLACEHOLDER_2 projection=[ps_partkey, ps_suppkey, ps_availqty, ps_supplycost, ps_comment]\ \n Inner Join: Filter: Boolean(true)\ \n TableScan: FILENAME_PLACEHOLDER_0 projection=[s_suppkey, s_name, s_address, s_nationkey, s_phone, s_acctbal, s_comment]\ - \n TableScan: FILENAME_PLACEHOLDER_1 projection=[n_nationkey, n_name, n_regionkey, n_comment]"); + \n TableScan: FILENAME_PLACEHOLDER_1 projection=[N_NATIONKEY, N_NAME, N_REGIONKEY, N_COMMENT]"); Ok(()) } @@ -543,7 +543,7 @@ mod tests { \n Sort: count(Int64(1)) DESC NULLS FIRST, FILENAME_PLACEHOLDER_0.s_name ASC NULLS LAST\ \n Aggregate: groupBy=[[FILENAME_PLACEHOLDER_0.s_name]], aggr=[[count(Int64(1))]]\ \n Projection: FILENAME_PLACEHOLDER_0.s_name\ - \n Filter: FILENAME_PLACEHOLDER_0.s_suppkey = FILENAME_PLACEHOLDER_1.l_suppkey AND FILENAME_PLACEHOLDER_2.o_orderkey = FILENAME_PLACEHOLDER_1.l_orderkey AND FILENAME_PLACEHOLDER_2.o_orderstatus = Utf8(\"F\") AND FILENAME_PLACEHOLDER_1.l_receiptdate > FILENAME_PLACEHOLDER_1.l_commitdate AND EXISTS () AND NOT EXISTS () AND FILENAME_PLACEHOLDER_0.s_nationkey = FILENAME_PLACEHOLDER_3.n_nationkey AND FILENAME_PLACEHOLDER_3.n_name = CAST(Utf8(\"SAUDI ARABIA\") AS Utf8)\ + \n Filter: FILENAME_PLACEHOLDER_0.s_suppkey = FILENAME_PLACEHOLDER_1.l_suppkey AND FILENAME_PLACEHOLDER_2.o_orderkey = FILENAME_PLACEHOLDER_1.l_orderkey AND 
FILENAME_PLACEHOLDER_2.o_orderstatus = Utf8(\"F\") AND FILENAME_PLACEHOLDER_1.l_receiptdate > FILENAME_PLACEHOLDER_1.l_commitdate AND EXISTS () AND NOT EXISTS () AND FILENAME_PLACEHOLDER_0.s_nationkey = FILENAME_PLACEHOLDER_3.N_NATIONKEY AND FILENAME_PLACEHOLDER_3.N_NAME = CAST(Utf8(\"SAUDI ARABIA\") AS Utf8)\ \n Subquery:\ \n Filter: FILENAME_PLACEHOLDER_4.l_orderkey = FILENAME_PLACEHOLDER_4.l_tax AND FILENAME_PLACEHOLDER_4.l_suppkey != FILENAME_PLACEHOLDER_4.l_linestatus\ \n TableScan: FILENAME_PLACEHOLDER_4 projection=[l_orderkey, l_partkey, l_suppkey, l_linenumber, l_quantity, l_extendedprice, l_discount, l_tax, l_returnflag, l_linestatus, l_shipdate, l_commitdate, l_receiptdate, l_shipinstruct, l_shipmode, l_comment]\ @@ -555,7 +555,7 @@ mod tests { \n Inner Join: Filter: Boolean(true)\ \n TableScan: FILENAME_PLACEHOLDER_0 projection=[s_suppkey, s_name, s_address, s_nationkey, s_phone, s_acctbal, s_comment]\ \n TableScan: FILENAME_PLACEHOLDER_1 projection=[l_orderkey, l_partkey, l_suppkey, l_linenumber, l_quantity, l_extendedprice, l_discount, l_tax, l_returnflag, l_linestatus, l_shipdate, l_commitdate, l_receiptdate, l_shipinstruct, l_shipmode, l_comment]\n TableScan: FILENAME_PLACEHOLDER_2 projection=[o_orderkey, o_custkey, o_orderstatus, o_totalprice, o_orderdate, o_orderpriority, o_clerk, o_shippriority, o_comment]\ - \n TableScan: FILENAME_PLACEHOLDER_3 projection=[n_nationkey, n_name, n_regionkey, n_comment]"); + \n TableScan: FILENAME_PLACEHOLDER_3 projection=[N_NATIONKEY, N_NAME, N_REGIONKEY, N_COMMENT]"); Ok(()) } diff --git a/datafusion/substrait/tests/cases/function_test.rs b/datafusion/substrait/tests/cases/function_test.rs index 610caf3a81df..85809da6f3e4 100644 --- a/datafusion/substrait/tests/cases/function_test.rs +++ b/datafusion/substrait/tests/cases/function_test.rs @@ -19,40 +19,26 @@ #[cfg(test)] mod tests { + use crate::utils::test::{add_plan_schemas_to_ctx, read_json}; + use datafusion::common::Result; - use datafusion::prelude::{CsvReadOptions, SessionContext}; + use datafusion::prelude::SessionContext; use datafusion_substrait::logical_plan::consumer::from_substrait_plan; - use std::fs::File; - use std::io::BufReader; - use substrait::proto::Plan; #[tokio::test] async fn contains_function_test() -> Result<()> { - let ctx = create_context().await?; - - let path = "tests/testdata/contains_plan.substrait.json"; - let proto = serde_json::from_reader::<_, Plan>(BufReader::new( - File::open(path).expect("file not found"), - )) - .expect("failed to parse json"); - - let plan = from_substrait_plan(&ctx, &proto).await?; + let proto_plan = read_json("tests/testdata/contains_plan.substrait.json"); + let ctx = add_plan_schemas_to_ctx(SessionContext::new(), &proto_plan); + let plan = from_substrait_plan(&ctx, &proto_plan).await?; let plan_str = format!("{}", plan); assert_eq!( plan_str, - "Projection: nation.b AS n_name\ - \n Filter: contains(nation.b, Utf8(\"IA\"))\ - \n TableScan: nation projection=[a, b, c, d, e, f]" + "Projection: nation.n_name\ + \n Filter: contains(nation.n_name, Utf8(\"IA\"))\ + \n TableScan: nation projection=[n_nationkey, n_name, n_regionkey, n_comment]" ); Ok(()) } - - async fn create_context() -> datafusion::common::Result { - let ctx = SessionContext::new(); - ctx.register_csv("nation", "tests/testdata/data.csv", CsvReadOptions::new()) - .await?; - Ok(ctx) - } } diff --git a/datafusion/substrait/tests/cases/logical_plans.rs b/datafusion/substrait/tests/cases/logical_plans.rs index f6a2b5036c80..8db2aa283d3c 100644 --- 
a/datafusion/substrait/tests/cases/logical_plans.rs +++ b/datafusion/substrait/tests/cases/logical_plans.rs @@ -19,13 +19,11 @@ #[cfg(test)] mod tests { + use crate::utils::test::{add_plan_schemas_to_ctx, read_json}; use datafusion::common::Result; use datafusion::dataframe::DataFrame; - use datafusion::prelude::{CsvReadOptions, SessionContext}; + use datafusion::prelude::SessionContext; use datafusion_substrait::logical_plan::consumer::from_substrait_plan; - use std::fs::File; - use std::io::BufReader; - use substrait::proto::Plan; #[tokio::test] async fn scalar_function_compound_signature() -> Result<()> { @@ -35,18 +33,17 @@ mod tests { // we don't yet produce such plans. // Once we start producing plans with compound signatures, this test can be replaced by the roundtrip tests. - let ctx = create_context().await?; - // File generated with substrait-java's Isthmus: - // ./isthmus-cli/build/graal/isthmus "select not d from data" -c "create table data (d boolean)" - let proto = read_json("tests/testdata/test_plans/select_not_bool.substrait.json"); - - let plan = from_substrait_plan(&ctx, &proto).await?; + // ./isthmus-cli/build/graal/isthmus --create "create table data (d boolean)" "select not d from data" + let proto_plan = + read_json("tests/testdata/test_plans/select_not_bool.substrait.json"); + let ctx = add_plan_schemas_to_ctx(SessionContext::new(), &proto_plan); + let plan = from_substrait_plan(&ctx, &proto_plan).await?; assert_eq!( format!("{}", plan), - "Projection: NOT DATA.a AS EXPR$0\ - \n TableScan: DATA projection=[a, b, c, d, e, f]" + "Projection: NOT DATA.D AS EXPR$0\ + \n TableScan: DATA projection=[D]" ); Ok(()) } @@ -61,19 +58,18 @@ mod tests { // we don't yet produce such plans. // Once we start producing plans with compound signatures, this test can be replaced by the roundtrip tests. 
- let ctx = create_context().await?; - // File generated with substrait-java's Isthmus: - // ./isthmus-cli/build/graal/isthmus "select sum(d) OVER (PARTITION BY part ORDER BY ord ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING) AS lead_expr from data" -c "create table data (d int, part int, ord int)" - let proto = read_json("tests/testdata/test_plans/select_window.substrait.json"); - - let plan = from_substrait_plan(&ctx, &proto).await?; + // ./isthmus-cli/build/graal/isthmus --create "create table data (d int, part int, ord int)" "select sum(d) OVER (PARTITION BY part ORDER BY ord ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING) AS lead_expr from data" + let proto_plan = + read_json("tests/testdata/test_plans/select_window.substrait.json"); + let ctx = add_plan_schemas_to_ctx(SessionContext::new(), &proto_plan); + let plan = from_substrait_plan(&ctx, &proto_plan).await?; assert_eq!( format!("{}", plan), - "Projection: sum(DATA.a) PARTITION BY [DATA.b] ORDER BY [DATA.c ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING AS LEAD_EXPR\ - \n WindowAggr: windowExpr=[[sum(DATA.a) PARTITION BY [DATA.b] ORDER BY [DATA.c ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING]]\ - \n TableScan: DATA projection=[a, b, c, d, e, f]" + "Projection: sum(DATA.D) PARTITION BY [DATA.PART] ORDER BY [DATA.ORD ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING AS LEAD_EXPR\ + \n WindowAggr: windowExpr=[[sum(DATA.D) PARTITION BY [DATA.PART] ORDER BY [DATA.ORD ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING]]\ + \n TableScan: DATA projection=[D, PART, ORD]" ); Ok(()) } @@ -83,11 +79,10 @@ mod tests { // DataFusion's Substrait consumer treats all lists as nullable, even if the Substrait plan specifies them as non-nullable. // That's because implementing the non-nullability consistently is non-trivial. // This test confirms that reading a plan with non-nullable lists works as expected. - let ctx = create_context().await?; - let proto = + let proto_plan = read_json("tests/testdata/test_plans/non_nullable_lists.substrait.json"); - - let plan = from_substrait_plan(&ctx, &proto).await?; + let ctx = add_plan_schemas_to_ctx(SessionContext::new(), &proto_plan); + let plan = from_substrait_plan(&ctx, &proto_plan).await?; assert_eq!(format!("{}", &plan), "Values: (List([1, 2]))"); @@ -96,18 +91,4 @@ mod tests { Ok(()) } - - fn read_json(path: &str) -> Plan { - serde_json::from_reader::<_, Plan>(BufReader::new( - File::open(path).expect("file not found"), - )) - .expect("failed to parse json") - } - - async fn create_context() -> datafusion::common::Result { - let ctx = SessionContext::new(); - ctx.register_csv("DATA", "tests/testdata/data.csv", CsvReadOptions::new()) - .await?; - Ok(ctx) - } } diff --git a/datafusion/substrait/tests/cases/mod.rs b/datafusion/substrait/tests/cases/mod.rs index d3ea7695e4b9..42aa23626106 100644 --- a/datafusion/substrait/tests/cases/mod.rs +++ b/datafusion/substrait/tests/cases/mod.rs @@ -21,3 +21,4 @@ mod logical_plans; mod roundtrip_logical_plan; mod roundtrip_physical_plan; mod serialize; +mod substrait_validations; diff --git a/datafusion/substrait/tests/cases/substrait_validations.rs b/datafusion/substrait/tests/cases/substrait_validations.rs new file mode 100644 index 000000000000..cb1fb67fc044 --- /dev/null +++ b/datafusion/substrait/tests/cases/substrait_validations.rs @@ -0,0 +1,151 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#[cfg(test)] +mod tests { + + // verify the schema compatability validations + mod schema_compatability { + use crate::utils::test::read_json; + use datafusion::arrow::datatypes::{DataType, Field}; + use datafusion::catalog_common::TableReference; + use datafusion::common::{DFSchema, Result}; + use datafusion::datasource::empty::EmptyTable; + use datafusion::prelude::SessionContext; + use datafusion_substrait::logical_plan::consumer::from_substrait_plan; + use std::collections::HashMap; + use std::sync::Arc; + + fn generate_context_with_table( + table_name: &str, + fields: Vec<(&str, DataType, bool)>, + ) -> Result<SessionContext> { + let table_ref = TableReference::bare(table_name); + let fields: Vec<(Option<TableReference>, Arc<Field>)> = fields + .into_iter() + .map(|pair| { + let (field_name, data_type, nullable) = pair; + ( + Some(table_ref.clone()), + Arc::new(Field::new(field_name, data_type, nullable)), + ) + }) + .collect(); + + let df_schema = DFSchema::new_with_metadata(fields, HashMap::default())?; + + let ctx = SessionContext::new(); + ctx.register_table( + table_ref, + Arc::new(EmptyTable::new(df_schema.inner().clone())), + )?; + Ok(ctx) + } + + #[tokio::test] + async fn ensure_schema_match_exact() -> Result<()> { + let proto_plan = + read_json("tests/testdata/test_plans/simple_select.substrait.json"); + // this is the exact schema of the Substrait plan + let df_schema = + vec![("a", DataType::Int32, false), ("b", DataType::Int32, true)]; + + let ctx = generate_context_with_table("DATA", df_schema)?; + let plan = from_substrait_plan(&ctx, &proto_plan).await?; + + assert_eq!( + format!("{}", plan), + "Projection: DATA.a, DATA.b\ + \n TableScan: DATA projection=[a, b]" + ); + Ok(()) + } + + #[tokio::test] + async fn ensure_schema_match_subset() -> Result<()> { + let proto_plan = + read_json("tests/testdata/test_plans/simple_select.substrait.json"); + // the DataFusion schema { b, a, c } contains the Substrait schema { a, b } + let df_schema = vec![ + ("b", DataType::Int32, true), + ("a", DataType::Int32, false), + ("c", DataType::Int32, false), + ]; + let ctx = generate_context_with_table("DATA", df_schema)?; + let plan = from_substrait_plan(&ctx, &proto_plan).await?; + + assert_eq!( + format!("{}", plan), + "Projection: DATA.a, DATA.b\ + \n Projection: DATA.a, DATA.b\ + \n TableScan: DATA projection=[b, a]" + ); + Ok(()) + } + + #[tokio::test] + async fn ensure_schema_match_subset_with_mask() -> Result<()> { + let proto_plan = read_json( + "tests/testdata/test_plans/simple_select_with_mask.substrait.json", + ); + // the DataFusion schema { b, a, c, d } contains the Substrait schema { a, b, c } + let df_schema = vec![ + ("b", DataType::Int32, true), + ("a", DataType::Int32, false), + ("c", DataType::Int32, false), + ("d", DataType::Int32, false), + ]; + let ctx = generate_context_with_table("DATA", df_schema)?; + let plan 
= from_substrait_plan(&ctx, &proto_plan).await?; + + assert_eq!( + format!("{}", plan), + "Projection: DATA.a, DATA.b\ + \n Projection: DATA.a, DATA.b\ + \n Projection: DATA.a, DATA.b, DATA.c\ + \n TableScan: DATA projection=[b, a, c]" + ); + Ok(()) + } + + #[tokio::test] + async fn ensure_schema_match_not_subset() -> Result<()> { + let proto_plan = + read_json("tests/testdata/test_plans/simple_select.substrait.json"); + // the substrait plans contains a field b which is not in the schema + let df_schema = + vec![("a", DataType::Int32, false), ("c", DataType::Int32, true)]; + + let ctx = generate_context_with_table("DATA", df_schema)?; + let res = from_substrait_plan(&ctx, &proto_plan).await; + assert!(res.is_err()); + Ok(()) + } + + #[tokio::test] + async fn reject_plans_with_incompatible_field_types() -> Result<()> { + let proto_plan = + read_json("tests/testdata/test_plans/simple_select.substrait.json"); + + let ctx = + generate_context_with_table("DATA", vec![("a", DataType::Date32, true)])?; + let res = from_substrait_plan(&ctx, &proto_plan).await; + assert!(res.is_err()); + Ok(()) + } + } +} diff --git a/datafusion/substrait/tests/substrait_integration.rs b/datafusion/substrait/tests/substrait_integration.rs index 6ce41c9de71a..eedd4da373e0 100644 --- a/datafusion/substrait/tests/substrait_integration.rs +++ b/datafusion/substrait/tests/substrait_integration.rs @@ -17,3 +17,4 @@ /// Run all tests that are found in the `cases` directory mod cases; +mod utils; diff --git a/datafusion/substrait/tests/testdata/test_plans/simple_select.substrait.json b/datafusion/substrait/tests/testdata/test_plans/simple_select.substrait.json new file mode 100644 index 000000000000..aee27ef3b417 --- /dev/null +++ b/datafusion/substrait/tests/testdata/test_plans/simple_select.substrait.json @@ -0,0 +1,69 @@ +{ + "extensionUris": [], + "extensions": [], + "relations": [{ + "root": { + "input": { + "project": { + "common": { + "emit": { + "outputMapping": [2, 3] + } + }, + "input": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["a", "b"], + "struct": { + "types": [{ + "i32": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, { + "i32": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }], + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": ["DATA"] + } + } + }, + "expressions": [{ + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + }] + } + }, + "names": ["a", "b"] + } + }], + "expectedTypeUrls": [] +} \ No newline at end of file diff --git a/datafusion/substrait/tests/testdata/test_plans/simple_select_with_mask.substrait.json b/datafusion/substrait/tests/testdata/test_plans/simple_select_with_mask.substrait.json new file mode 100644 index 000000000000..774126ca3836 --- /dev/null +++ b/datafusion/substrait/tests/testdata/test_plans/simple_select_with_mask.substrait.json @@ -0,0 +1,104 @@ +{ + "extensionUris": [], + "extensions": [], + "relations": [ + { + "root": { + "input": { + "project": { + "common": { + "emit": { + "outputMapping": [ + 2, + 3 + ] + } + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "a", + "b", + "c" + ], + "struct": { + "types": [ + { + "i32": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + 
} + }, + { + "i32": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "i32": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + } + ], + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": [ + "DATA" + ] + }, + "projection": { + "select": { + "struct_items": [ + { + "field": 0 + }, + { + "field": 1 + } + ] + } + } + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": {} + } + } + ] + } + }, + "names": [ + "a", + "b" + ] + } + } + ], + "expectedTypeUrls": [] +} diff --git a/datafusion/substrait/tests/testdata/tpch/nation.csv b/datafusion/substrait/tests/testdata/tpch/nation.csv index fdf7421467d3..a88d1c0d31e7 100644 --- a/datafusion/substrait/tests/testdata/tpch/nation.csv +++ b/datafusion/substrait/tests/testdata/tpch/nation.csv @@ -1,2 +1,2 @@ -n_nationkey,n_name,n_regionkey,n_comment +N_NATIONKEY,N_NAME,N_REGIONKEY,N_COMMENT 0,ALGERIA,0, haggle. carefully final deposits detect slyly agai \ No newline at end of file diff --git a/datafusion/substrait/tests/testdata/tpch/region.csv b/datafusion/substrait/tests/testdata/tpch/region.csv index 6c3fb4524355..d29c39ab8543 100644 --- a/datafusion/substrait/tests/testdata/tpch/region.csv +++ b/datafusion/substrait/tests/testdata/tpch/region.csv @@ -1,2 +1,2 @@ -r_regionkey,r_name,r_comment +R_REGIONKEY,R_NAME,R_COMMENT 0,AFRICA,lar deposits. blithely final packages cajole. regular waters are final requests. regular accounts are according to \ No newline at end of file diff --git a/datafusion/substrait/tests/utils.rs b/datafusion/substrait/tests/utils.rs new file mode 100644 index 000000000000..685e3deec581 --- /dev/null +++ b/datafusion/substrait/tests/utils.rs @@ -0,0 +1,186 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#[cfg(test)] +pub mod test { + use datafusion::catalog_common::TableReference; + use datafusion::datasource::empty::EmptyTable; + use datafusion::datasource::TableProvider; + use datafusion::prelude::SessionContext; + use datafusion_substrait::extensions::Extensions; + use datafusion_substrait::logical_plan::consumer::from_substrait_named_struct; + use std::fs::File; + use std::io::BufReader; + use std::sync::Arc; + use substrait::proto::read_rel::{NamedTable, ReadType}; + use substrait::proto::rel::RelType; + use substrait::proto::{Plan, ReadRel, Rel}; + + pub fn read_json(path: &str) -> Plan { + serde_json::from_reader::<_, Plan>(BufReader::new( + File::open(path).expect("file not found"), + )) + .expect("failed to parse json") + } + + pub fn add_plan_schemas_to_ctx(ctx: SessionContext, plan: &Plan) -> SessionContext { + let schemas = TestSchemaCollector::collect_schemas(plan); + for (table_reference, table) in schemas { + ctx.register_table(table_reference, table) + .expect("Failed to register table"); + } + ctx + } + + pub struct TestSchemaCollector { + schemas: Vec<(TableReference, Arc<dyn TableProvider>)>, + } + + impl TestSchemaCollector { + fn new() -> Self { + TestSchemaCollector { + schemas: Vec::new(), + } + } + + fn collect_schemas(plan: &Plan) -> Vec<(TableReference, Arc<dyn TableProvider>)> { + let mut schema_collector = Self::new(); + + for plan_rel in plan.relations.iter() { + match plan_rel + .rel_type + .as_ref() + .expect("PlanRel must set rel_type") + { + substrait::proto::plan_rel::RelType::Rel(r) => { + schema_collector.collect_schemas_from_rel(r) + } + substrait::proto::plan_rel::RelType::Root(r) => schema_collector + .collect_schemas_from_rel( + r.input.as_ref().expect("RelRoot must set input"), + ), + } + } + schema_collector.schemas + } + + fn collect_named_table(&mut self, read: &ReadRel, nt: &NamedTable) { + let table_reference = match nt.names.len() { + 0 => { + panic!("No table name found in NamedTable"); + } + 1 => TableReference::Bare { + table: nt.names[0].clone().into(), + }, + 2 => TableReference::Partial { + schema: nt.names[0].clone().into(), + table: nt.names[1].clone().into(), + }, + _ => TableReference::Full { + catalog: nt.names[0].clone().into(), + schema: nt.names[1].clone().into(), + table: nt.names[2].clone().into(), + }, + }; + + let substrait_schema = read + .base_schema + .as_ref() + .expect("No base schema found for NamedTable"); + let empty_extensions = Extensions { + functions: Default::default(), + types: Default::default(), + type_variations: Default::default(), + }; + + let df_schema = + from_substrait_named_struct(substrait_schema, &empty_extensions) + .expect( + "Unable to generate DataFusion schema from Substrait NamedStruct", + ) + .replace_qualifier(table_reference.clone()); + + let table = EmptyTable::new(df_schema.inner().clone()); + self.schemas.push((table_reference, Arc::new(table))); + } + + fn collect_schemas_from_rel(&mut self, rel: &Rel) { + match rel.rel_type.as_ref().unwrap() { + RelType::Read(r) => match r.read_type.as_ref().unwrap() { + // Virtual Tables do not contribute to the schema + ReadType::VirtualTable(_) => (), + ReadType::LocalFiles(_) => todo!(), + ReadType::NamedTable(nt) => self.collect_named_table(r, nt), + ReadType::ExtensionTable(_) => todo!(), + }, + RelType::Filter(f) => self.apply(f.input.as_ref().map(|b| b.as_ref())), + RelType::Fetch(f) => self.apply(f.input.as_ref().map(|b| b.as_ref())), + RelType::Aggregate(a) => self.apply(a.input.as_ref().map(|b| b.as_ref())), + RelType::Sort(s) => self.apply(s.input.as_ref().map(|b| 
b.as_ref())), + RelType::Join(j) => { + self.apply(j.left.as_ref().map(|b| b.as_ref())); + self.apply(j.right.as_ref().map(|b| b.as_ref())); + } + RelType::Project(p) => self.apply(p.input.as_ref().map(|b| b.as_ref())), + RelType::Set(s) => { + for input in s.inputs.iter() { + self.collect_schemas_from_rel(input); + } + } + RelType::ExtensionSingle(s) => { + self.apply(s.input.as_ref().map(|b| b.as_ref())) + } + RelType::ExtensionMulti(m) => { + for input in m.inputs.iter() { + self.collect_schemas_from_rel(input) + } + } + RelType::ExtensionLeaf(_) => {} + RelType::Cross(c) => { + self.apply(c.left.as_ref().map(|b| b.as_ref())); + self.apply(c.right.as_ref().map(|b| b.as_ref())); + } + // RelType::Reference(_) => {} + // RelType::Write(_) => {} + // RelType::Ddl(_) => {} + RelType::HashJoin(j) => { + self.apply(j.left.as_ref().map(|b| b.as_ref())); + self.apply(j.right.as_ref().map(|b| b.as_ref())); + } + RelType::MergeJoin(j) => { + self.apply(j.left.as_ref().map(|b| b.as_ref())); + self.apply(j.right.as_ref().map(|b| b.as_ref())); + } + RelType::NestedLoopJoin(j) => { + self.apply(j.left.as_ref().map(|b| b.as_ref())); + self.apply(j.right.as_ref().map(|b| b.as_ref())); + } + RelType::Window(w) => self.apply(w.input.as_ref().map(|b| b.as_ref())), + RelType::Exchange(e) => self.apply(e.input.as_ref().map(|b| b.as_ref())), + RelType::Expand(e) => self.apply(e.input.as_ref().map(|b| b.as_ref())), + _ => todo!(), + } + } + + fn apply(&mut self, input: Option<&Rel>) { + match input { + None => {} + Some(rel) => self.collect_schemas_from_rel(rel), + } + } + } +}
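
Illustrative usage sketch (not applied by this patch): the helpers added in datafusion/substrait/tests/utils.rs are meant to be combined the way the updated consumer tests above use them -- read a Substrait plan from JSON, let add_plan_schemas_to_ctx register an EmptyTable for every NamedTable the plan references (built from the plan's own base_schema), and only then hand the plan to the consumer, so the registered DataFusion schema and the Substrait schema cannot drift apart. A new consumer test written against these helpers could look roughly like the following; the test name is hypothetical, while read_json, add_plan_schemas_to_ctx, from_substrait_plan, and the plan path are all taken from the patch above.

use crate::utils::test::{add_plan_schemas_to_ctx, read_json};
use datafusion::common::Result;
use datafusion::prelude::SessionContext;
use datafusion_substrait::logical_plan::consumer::from_substrait_plan;

#[tokio::test]
async fn named_table_schemas_come_from_the_plan() -> Result<()> {
    // Parse one of the Substrait plan JSON files added by this patch.
    let proto_plan = read_json("tests/testdata/test_plans/simple_select.substrait.json");
    // Register an EmptyTable for each NamedTable in the plan, using the plan's base_schema.
    let ctx = add_plan_schemas_to_ctx(SessionContext::new(), &proto_plan);
    // Consuming the plan now also exercises the new NamedTable schema validation.
    let plan = from_substrait_plan(&ctx, &proto_plan).await?;
    println!("{plan}");
    Ok(())
}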