Implement TPCH substrait integration test, support tpch_2 (apache#11234)

* integrate tpch query 2

avoid cloning

optimize code

optimize code

* optimize code

* refactor code

* format
Lordworms authored Jul 5, 2024
1 parent 351e5f9 commit 0d2525e
Showing 8 changed files with 1,769 additions and 58 deletions.
142 changes: 94 additions & 48 deletions datafusion/substrait/src/logical_plan/consumer.rs
@@ -15,25 +15,38 @@
// specific language governing permissions and limitations
// under the License.

use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano};
use async_recursion::async_recursion;
use datafusion::arrow::array::GenericListArray;
use datafusion::arrow::datatypes::{
DataType, Field, FieldRef, Fields, IntervalUnit, Schema, TimeUnit,
};
use datafusion::common::plan_err;
use datafusion::common::{
not_impl_datafusion_err, not_impl_err, plan_datafusion_err, plan_err,
substrait_datafusion_err, substrait_err, DFSchema, DFSchemaRef,
not_impl_datafusion_err, not_impl_err, plan_datafusion_err, substrait_datafusion_err,
substrait_err, DFSchema, DFSchemaRef,
};
use substrait::proto::expression::literal::IntervalDayToSecond;
use substrait::proto::read_rel::local_files::file_or_files::PathType::UriFile;
use url::Url;

use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano};
use datafusion::execution::FunctionRegistry;
use datafusion::logical_expr::expr::{InSubquery, Sort};

use datafusion::logical_expr::{
aggregate_function, expr::find_df_window_func, Aggregate, BinaryExpr, Case,
EmptyRelation, Expr, ExprSchemable, LogicalPlan, Operator, Projection, Values,
};
use url::Url;

use crate::variation_const::{
DATE_32_TYPE_VARIATION_REF, DATE_64_TYPE_VARIATION_REF,
DECIMAL_128_TYPE_VARIATION_REF, DECIMAL_256_TYPE_VARIATION_REF,
DEFAULT_CONTAINER_TYPE_VARIATION_REF, DEFAULT_TYPE_VARIATION_REF,
INTERVAL_DAY_TIME_TYPE_REF, INTERVAL_MONTH_DAY_NANO_TYPE_REF,
INTERVAL_YEAR_MONTH_TYPE_REF, LARGE_CONTAINER_TYPE_VARIATION_REF,
TIMESTAMP_MICRO_TYPE_VARIATION_REF, TIMESTAMP_MILLI_TYPE_VARIATION_REF,
TIMESTAMP_NANO_TYPE_VARIATION_REF, TIMESTAMP_SECOND_TYPE_VARIATION_REF,
UNSIGNED_INTEGER_TYPE_VARIATION_REF,
};
use datafusion::common::scalar::ScalarStructBuilder;
use datafusion::logical_expr::expr::InList;
use datafusion::logical_expr::{
col, expr, Cast, Extension, GroupingSet, Like, LogicalPlanBuilder, Partitioning,
Repartition, Subquery, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition,
@@ -46,10 +59,15 @@ use datafusion::{
prelude::{Column, SessionContext},
scalar::ScalarValue,
};
use std::collections::HashMap;
use std::str::FromStr;
use std::sync::Arc;
use substrait::proto::exchange_rel::ExchangeKind;
use substrait::proto::expression::literal::user_defined::Val;
use substrait::proto::expression::literal::IntervalDayToSecond;
use substrait::proto::expression::subquery::SubqueryType;
use substrait::proto::expression::{self, FieldReference, Literal, ScalarFunction};
use substrait::proto::read_rel::local_files::file_or_files::PathType::UriFile;
use substrait::proto::{
aggregate_function::AggregationInvocation,
expression::{
@@ -70,24 +88,6 @@ use substrait::proto::{
};
use substrait::proto::{FunctionArgument, SortField};

use datafusion::arrow::array::GenericListArray;
use datafusion::common::scalar::ScalarStructBuilder;
use datafusion::logical_expr::expr::{InList, InSubquery, Sort};
use std::collections::HashMap;
use std::str::FromStr;
use std::sync::Arc;

use crate::variation_const::{
DATE_32_TYPE_VARIATION_REF, DATE_64_TYPE_VARIATION_REF,
DECIMAL_128_TYPE_VARIATION_REF, DECIMAL_256_TYPE_VARIATION_REF,
DEFAULT_CONTAINER_TYPE_VARIATION_REF, DEFAULT_TYPE_VARIATION_REF,
INTERVAL_DAY_TIME_TYPE_REF, INTERVAL_MONTH_DAY_NANO_TYPE_REF,
INTERVAL_YEAR_MONTH_TYPE_REF, LARGE_CONTAINER_TYPE_VARIATION_REF,
TIMESTAMP_MICRO_TYPE_VARIATION_REF, TIMESTAMP_MILLI_TYPE_VARIATION_REF,
TIMESTAMP_NANO_TYPE_VARIATION_REF, TIMESTAMP_SECOND_TYPE_VARIATION_REF,
UNSIGNED_INTEGER_TYPE_VARIATION_REF,
};

pub fn name_to_op(name: &str) -> Result<Operator> {
match name {
"equal" => Ok(Operator::Eq),
@@ -1125,17 +1125,32 @@ pub async fn from_substrait_rex(
expr::ScalarFunction::new_udf(func.to_owned(), args),
)))
} else if let Ok(op) = name_to_op(fn_name) {
if args.len() != 2 {
if f.arguments.len() < 2 {
return not_impl_err!(
"Expect two arguments for binary operator {op:?}"
"Expect at least two arguments for binary operator {op:?}, the provided number of operators is {:?}",
f.arguments.len()
);
}
// Some expressions are binary in DataFusion but take in a variadic number of args in Substrait.
// In those cases we iterate through all the arguments, applying the binary expression against them all
let combined_expr = args
.into_iter()
.fold(None, |combined_expr: Option<Arc<Expr>>, arg: Expr| {
Some(match combined_expr {
Some(expr) => Arc::new(Expr::BinaryExpr(BinaryExpr {
left: Box::new(
Arc::try_unwrap(expr)
.unwrap_or_else(|arc: Arc<Expr>| (*arc).clone()),
), // Avoid cloning if possible
op: op.clone(),
right: Box::new(arg),
})),
None => Arc::new(arg),
})
})
.unwrap();

Ok(Arc::new(Expr::BinaryExpr(BinaryExpr {
left: Box::new(args[0].to_owned()),
op,
right: Box::new(args[1].to_owned()),
})))
Ok(combined_expr)
} else if let Some(builder) = BuiltinExprBuilder::try_from_name(fn_name) {
builder.build(ctx, f, input_schema, extensions).await
} else {
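
A minimal standalone sketch of the fold above: a variadic Substrait call such as and(a, b, c) is collapsed into the left-deep binary tree (a AND b) AND c. The ToyExpr type below is hypothetical; the real consumer folds DataFusion's Expr and wraps intermediates in Arc (unwrapped with Arc::try_unwrap) to avoid clones.

#[derive(Debug)]
enum ToyExpr {
    Column(String),
    Binary {
        left: Box<ToyExpr>,
        op: String,
        right: Box<ToyExpr>,
    },
}

fn fold_variadic(args: Vec<ToyExpr>, op: &str) -> Option<ToyExpr> {
    args.into_iter().fold(None, |acc, arg| {
        Some(match acc {
            // Each subsequent argument becomes the right operand of a new binary node.
            Some(left) => ToyExpr::Binary {
                left: Box::new(left),
                op: op.to_string(),
                right: Box::new(arg),
            },
            // The first argument seeds the accumulator.
            None => arg,
        })
    })
}

fn main() {
    let args = vec![
        ToyExpr::Column("a".into()),
        ToyExpr::Column("b".into()),
        ToyExpr::Column("c".into()),
    ];
    // Prints a left-deep tree equivalent to (a AND b) AND c.
    println!("{:?}", fold_variadic(args, "AND"));
}
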
@@ -1269,7 +1284,22 @@ pub async fn from_substrait_rex(
}
}
}
_ => substrait_err!("Subquery type not implemented"),
SubqueryType::Scalar(query) => {
let plan = from_substrait_rel(
ctx,
&(query.input.clone()).unwrap_or_default(),
extensions,
)
.await?;
let outer_ref_columns = plan.all_out_ref_exprs();
Ok(Arc::new(Expr::ScalarSubquery(Subquery {
subquery: Arc::new(plan),
outer_ref_columns,
})))
}
other_type => {
substrait_err!("Subquery type {:?} not implemented", other_type)
}
},
None => {
substrait_err!("Subquery experssion without SubqueryType is not allowed")
@@ -1699,6 +1729,7 @@ fn from_substrait_literal(
})) => {
ScalarValue::new_interval_dt(*days, (seconds * 1000) + (microseconds / 1000))
}
Some(LiteralType::FixedChar(c)) => ScalarValue::Utf8(Some(c.clone())),
Some(LiteralType::UserDefined(user_defined)) => {
match user_defined.type_reference {
INTERVAL_YEAR_MONTH_TYPE_REF => {
@@ -1988,8 +2019,8 @@ impl BuiltinExprBuilder {
extensions: &HashMap<u32, &String>,
) -> Result<Arc<Expr>> {
let fn_name = if case_insensitive { "ILIKE" } else { "LIKE" };
if f.arguments.len() != 3 {
return substrait_err!("Expect three arguments for `{fn_name}` expr");
if f.arguments.len() != 2 && f.arguments.len() != 3 {
return substrait_err!("Expect two or three arguments for `{fn_name}` expr");
}

let Some(ArgType::Value(expr_substrait)) = &f.arguments[0].arg_type else {
@@ -2007,25 +2038,40 @@
.await?
.as_ref()
.clone();
let Some(ArgType::Value(escape_char_substrait)) = &f.arguments[2].arg_type else {
return substrait_err!("Invalid arguments type for `{fn_name}` expr");
};
let escape_char_expr =
from_substrait_rex(ctx, escape_char_substrait, input_schema, extensions)
.await?
.as_ref()
.clone();
let Expr::Literal(ScalarValue::Utf8(escape_char)) = escape_char_expr else {
return substrait_err!(
"Expect Utf8 literal for escape char, but found {escape_char_expr:?}"
);

// Default case: escape character is Literal(Utf8(None))
let escape_char = if f.arguments.len() == 3 {
let Some(ArgType::Value(escape_char_substrait)) = &f.arguments[2].arg_type
else {
return substrait_err!("Invalid arguments type for `{fn_name}` expr");
};

let escape_char_expr =
from_substrait_rex(ctx, escape_char_substrait, input_schema, extensions)
.await?
.as_ref()
.clone();

match escape_char_expr {
Expr::Literal(ScalarValue::Utf8(escape_char_string)) => {
// Convert Option<String> to Option<char>
escape_char_string.and_then(|s| s.chars().next())
}
_ => {
return substrait_err!(
"Expect Utf8 literal for escape char, but found {escape_char_expr:?}"
)
}
}
} else {
None
};

Ok(Arc::new(Expr::Like(Like {
negated: false,
expr: Box::new(expr),
pattern: Box::new(pattern),
escape_char: escape_char.map(|c| c.chars().next().unwrap()),
escape_char,
case_insensitive,
})))
}
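
A small sketch of the escape-character narrowing above, using a hypothetical helper name: the optional Utf8 literal becomes the Option<char> that DataFusion's Like expression expects, and both a missing third argument and an empty string mean "no escape character".

// Hypothetical helper mirroring the narrowing done above.
fn escape_char_from_literal(escape: Option<String>) -> Option<char> {
    // None and "" both map to "no escape character".
    escape.and_then(|s| s.chars().next())
}

fn main() {
    assert_eq!(escape_char_from_literal(Some("\\".to_string())), Some('\\'));
    assert_eq!(escape_char_from_literal(Some(String::new())), None);
    assert_eq!(escape_char_from_literal(None), None);
}
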
93 changes: 83 additions & 10 deletions datafusion/substrait/tests/cases/consumer_integration.rs
@@ -32,9 +32,51 @@ mod tests {
use std::io::BufReader;
use substrait::proto::Plan;

async fn register_csv(
ctx: &SessionContext,
table_name: &str,
file_path: &str,
) -> Result<()> {
ctx.register_csv(table_name, file_path, CsvReadOptions::default())
.await
}

async fn create_context_tpch2() -> Result<SessionContext> {
let ctx = SessionContext::new();

let registrations = vec![
("FILENAME_PLACEHOLDER_0", "tests/testdata/tpch/part.csv"),
("FILENAME_PLACEHOLDER_1", "tests/testdata/tpch/supplier.csv"),
("FILENAME_PLACEHOLDER_2", "tests/testdata/tpch/partsupp.csv"),
("FILENAME_PLACEHOLDER_3", "tests/testdata/tpch/nation.csv"),
("FILENAME_PLACEHOLDER_4", "tests/testdata/tpch/region.csv"),
("FILENAME_PLACEHOLDER_5", "tests/testdata/tpch/partsupp.csv"),
("FILENAME_PLACEHOLDER_6", "tests/testdata/tpch/supplier.csv"),
("FILENAME_PLACEHOLDER_7", "tests/testdata/tpch/nation.csv"),
("FILENAME_PLACEHOLDER_8", "tests/testdata/tpch/region.csv"),
];

for (table_name, file_path) in registrations {
register_csv(&ctx, table_name, file_path).await?;
}

Ok(ctx)
}

async fn create_context_tpch1() -> Result<SessionContext> {
let ctx = SessionContext::new();
register_csv(
&ctx,
"FILENAME_PLACEHOLDER_0",
"tests/testdata/tpch/lineitem.csv",
)
.await?;
Ok(ctx)
}

#[tokio::test]
async fn tpch_test_1() -> Result<()> {
let ctx = create_context().await?;
let ctx = create_context_tpch1().await?;
let path = "tests/testdata/tpch_substrait_plans/query_1.json";
let proto = serde_json::from_reader::<_, Plan>(BufReader::new(
File::open(path).expect("file not found"),
@@ -56,14 +98,45 @@ mod tests {
Ok(())
}

async fn create_context() -> datafusion::common::Result<SessionContext> {
let ctx = SessionContext::new();
ctx.register_csv(
"FILENAME_PLACEHOLDER_0",
"tests/testdata/tpch/lineitem.csv",
CsvReadOptions::default(),
)
.await?;
Ok(ctx)
#[tokio::test]
async fn tpch_test_2() -> Result<()> {
let ctx = create_context_tpch2().await?;
let path = "tests/testdata/tpch_substrait_plans/query_2.json";
let proto = serde_json::from_reader::<_, Plan>(BufReader::new(
File::open(path).expect("file not found"),
))
.expect("failed to parse json");

let plan = from_substrait_plan(&ctx, &proto).await?;
let plan_str = format!("{:?}", plan);
assert_eq!(
plan_str,
"Projection: FILENAME_PLACEHOLDER_1.s_acctbal AS S_ACCTBAL, FILENAME_PLACEHOLDER_1.s_name AS S_NAME, FILENAME_PLACEHOLDER_3.n_name AS N_NAME, FILENAME_PLACEHOLDER_0.p_partkey AS P_PARTKEY, FILENAME_PLACEHOLDER_0.p_mfgr AS P_MFGR, FILENAME_PLACEHOLDER_1.s_address AS S_ADDRESS, FILENAME_PLACEHOLDER_1.s_phone AS S_PHONE, FILENAME_PLACEHOLDER_1.s_comment AS S_COMMENT\
\n Limit: skip=0, fetch=100\
\n Sort: FILENAME_PLACEHOLDER_1.s_acctbal DESC NULLS FIRST, FILENAME_PLACEHOLDER_3.n_name ASC NULLS LAST, FILENAME_PLACEHOLDER_1.s_name ASC NULLS LAST, FILENAME_PLACEHOLDER_0.p_partkey ASC NULLS LAST\
\n Projection: FILENAME_PLACEHOLDER_1.s_acctbal, FILENAME_PLACEHOLDER_1.s_name, FILENAME_PLACEHOLDER_3.n_name, FILENAME_PLACEHOLDER_0.p_partkey, FILENAME_PLACEHOLDER_0.p_mfgr, FILENAME_PLACEHOLDER_1.s_address, FILENAME_PLACEHOLDER_1.s_phone, FILENAME_PLACEHOLDER_1.s_comment\
\n Filter: FILENAME_PLACEHOLDER_0.p_partkey = FILENAME_PLACEHOLDER_2.ps_partkey AND FILENAME_PLACEHOLDER_1.s_suppkey = FILENAME_PLACEHOLDER_2.ps_suppkey AND FILENAME_PLACEHOLDER_0.p_size = Int32(15) AND FILENAME_PLACEHOLDER_0.p_type LIKE CAST(Utf8(\"%BRASS\") AS Utf8) AND FILENAME_PLACEHOLDER_1.s_nationkey = FILENAME_PLACEHOLDER_3.n_nationkey AND FILENAME_PLACEHOLDER_3.n_regionkey = FILENAME_PLACEHOLDER_4.r_regionkey AND FILENAME_PLACEHOLDER_4.r_name = CAST(Utf8(\"EUROPE\") AS Utf8) AND FILENAME_PLACEHOLDER_2.ps_supplycost = (<subquery>)\
\n Subquery:\
\n Aggregate: groupBy=[[]], aggr=[[MIN(FILENAME_PLACEHOLDER_5.ps_supplycost)]]\
\n Projection: FILENAME_PLACEHOLDER_5.ps_supplycost\
\n Filter: FILENAME_PLACEHOLDER_5.ps_partkey = FILENAME_PLACEHOLDER_5.ps_partkey AND FILENAME_PLACEHOLDER_6.s_suppkey = FILENAME_PLACEHOLDER_5.ps_suppkey AND FILENAME_PLACEHOLDER_6.s_nationkey = FILENAME_PLACEHOLDER_7.n_nationkey AND FILENAME_PLACEHOLDER_7.n_regionkey = FILENAME_PLACEHOLDER_8.r_regionkey AND FILENAME_PLACEHOLDER_8.r_name = CAST(Utf8(\"EUROPE\") AS Utf8)\
\n Inner Join: Filter: Boolean(true)\
\n Inner Join: Filter: Boolean(true)\
\n Inner Join: Filter: Boolean(true)\
\n TableScan: FILENAME_PLACEHOLDER_5 projection=[ps_partkey, ps_suppkey, ps_availqty, ps_supplycost, ps_comment]\
\n TableScan: FILENAME_PLACEHOLDER_6 projection=[s_suppkey, s_name, s_address, s_nationkey, s_phone, s_acctbal, s_comment]\
\n TableScan: FILENAME_PLACEHOLDER_7 projection=[n_nationkey, n_name, n_regionkey, n_comment]\
\n TableScan: FILENAME_PLACEHOLDER_8 projection=[r_regionkey, r_name, r_comment]\
\n Inner Join: Filter: Boolean(true)\
\n Inner Join: Filter: Boolean(true)\
\n Inner Join: Filter: Boolean(true)\
\n Inner Join: Filter: Boolean(true)\
\n TableScan: FILENAME_PLACEHOLDER_0 projection=[p_partkey, p_name, p_mfgr, p_brand, p_type, p_size, p_container, p_retailprice, p_comment]\
\n TableScan: FILENAME_PLACEHOLDER_1 projection=[s_suppkey, s_name, s_address, s_nationkey, s_phone, s_acctbal, s_comment]\
\n TableScan: FILENAME_PLACEHOLDER_2 projection=[ps_partkey, ps_suppkey, ps_availqty, ps_supplycost, ps_comment]\
\n TableScan: FILENAME_PLACEHOLDER_3 projection=[n_nationkey, n_name, n_regionkey, n_comment]\
\n TableScan: FILENAME_PLACEHOLDER_4 projection=[r_regionkey, r_name, r_comment]"
);
Ok(())
}
}
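
The same consumer can be driven outside the test harness; a hedged sketch assuming the crate paths used in the test above and that tokio and serde_json are available as dependencies:

use datafusion::error::Result;
use datafusion::prelude::{CsvReadOptions, SessionContext};
use datafusion_substrait::logical_plan::consumer::from_substrait_plan;
use std::fs::File;
use std::io::BufReader;
use substrait::proto::Plan;

#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();
    // Register the CSV under the placeholder table name the plan refers to.
    ctx.register_csv(
        "FILENAME_PLACEHOLDER_0",
        "tests/testdata/tpch/lineitem.csv",
        CsvReadOptions::default(),
    )
    .await?;

    // Deserialize the JSON-encoded Substrait plan and convert it to a DataFusion logical plan.
    let proto: Plan = serde_json::from_reader(BufReader::new(
        File::open("tests/testdata/tpch_substrait_plans/query_1.json")
            .expect("file not found"),
    ))
    .expect("failed to parse json");

    let plan = from_substrait_plan(&ctx, &proto).await?;
    println!("{plan:?}");
    Ok(())
}
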
2 changes: 2 additions & 0 deletions datafusion/substrait/tests/testdata/tpch/nation.csv
@@ -0,0 +1,2 @@
n_nationkey,n_name,n_regionkey,n_comment
0,ALGERIA,0, haggle. carefully final deposits detect slyly agai
2 changes: 2 additions & 0 deletions datafusion/substrait/tests/testdata/tpch/part.csv
@@ -0,0 +1,2 @@
p_partkey,p_name,p_mfgr,p_brand,p_type,p_size,p_container,p_retailprice,p_comment
1,pink powder puff,Manufacturer#1,Brand#13,SMALL PLATED COPPER,7,JUMBO PKG,901.00,ly final dependencies: slyly bold
2 changes: 2 additions & 0 deletions datafusion/substrait/tests/testdata/tpch/partsupp.csv
@@ -0,0 +1,2 @@
ps_partkey,ps_suppkey,ps_availqty,ps_supplycost,ps_comment
1,1,1000,50.00,slyly final packages boost against the slyly regular
2 changes: 2 additions & 0 deletions datafusion/substrait/tests/testdata/tpch/region.csv
@@ -0,0 +1,2 @@
r_regionkey,r_name,r_comment
0,AFRICA,lar deposits. blithely final packages cajole. regular waters are final requests. regular accounts are according to
2 changes: 2 additions & 0 deletions datafusion/substrait/tests/testdata/tpch/supplier.csv
@@ -0,0 +1,2 @@
s_suppkey,s_name,s_address,s_nationkey,s_phone,s_acctbal,s_comment
1,Supplier#1,123 Main St,0,555-1234,1000.00,No comments
datafusion/substrait/tests/testdata/tpch_substrait_plans/query_2.json (large diff not shown)
