From efb3a7db988247b5b3fa9836977cf50c941027bf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20Heres?= Date: Fri, 9 Jun 2023 15:50:54 +0200 Subject: [PATCH] Update DataFusion to 26 (#798) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Enable all tests * Adapt * Update DataFusion to 26 * Add physical_round_trip test * Fmt * Add cfg again * Do not enable q15 just yet * fmt * Fix * Fix * Fix * Fix * Fix * Schema fix * Undo some --------- Co-authored-by: Daniël Heres --- Cargo.toml | 13 +-- ballista-cli/Cargo.toml | 2 +- ballista-cli/src/exec.rs | 6 +- ballista/client/src/context.rs | 44 +++++----- ballista/scheduler/src/test_utils.rs | 48 +++++------ benchmarks/Cargo.toml | 1 + benchmarks/src/bin/tpch.rs | 117 ++++++++++++++++----------- 7 files changed, 129 insertions(+), 102 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 75986d3e4..c47dfdce2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,15 +19,16 @@ members = ["ballista-cli", "ballista/client", "ballista/core", "ballista/executor", "ballista/scheduler", "benchmarks", "examples"] [workspace.dependencies] -arrow = { version = "39.0.0" } -arrow-flight = { version = "39.0.0", features = ["flight-sql-experimental"] } +arrow = { version = "40.0.0" } +arrow-flight = { version = "40.0.0", features = ["flight-sql-experimental"] } +arrow-schema = { version = "40.0.0", default-features = false } configure_me = { version = "0.4.0" } configure_me_codegen = { version = "0.4.4" } -datafusion = "25.0.0" -datafusion-cli = "25.0.0" -datafusion-proto = "25.0.0" +datafusion = "26.0.0" +datafusion-cli = "26.0.0" +datafusion-proto = "26.0.0" object_store = "0.5.6" -sqlparser = "0.33.0" +sqlparser = "0.34.0" tonic = { version = "0.9" } tonic-build = { version = "0.9", default-features = false, features = ["transport", "prost"] } tracing = "0.1.36" diff --git a/ballista-cli/Cargo.toml b/ballista-cli/Cargo.toml index c805f3d13..f293dc574 100644 --- a/ballista-cli/Cargo.toml +++ b/ballista-cli/Cargo.toml @@ -37,7 +37,7 @@ dirs = "4.0.0" env_logger = "0.10" mimalloc = { version = "0.1", default-features = false } num_cpus = "1.13.0" -rustyline = "10.0" +rustyline = "11.0" tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread", "sync", "parking_lot"] } [features] diff --git a/ballista-cli/src/exec.rs b/ballista-cli/src/exec.rs index 4311cf2da..bd075b41c 100644 --- a/ballista-cli/src/exec.rs +++ b/ballista-cli/src/exec.rs @@ -90,7 +90,7 @@ pub async fn exec_from_files( /// run and execute SQL statements and commands against a context with the given print options pub async fn exec_from_repl(ctx: &BallistaContext, print_options: &mut PrintOptions) { - let mut rl = Editor::::new().expect("created editor"); + let mut rl = Editor::new().expect("created editor"); rl.set_helper(Some(CliHelper::default())); rl.load_history(".history").ok(); @@ -99,7 +99,7 @@ pub async fn exec_from_repl(ctx: &BallistaContext, print_options: &mut PrintOpti loop { match rl.readline("❯ ") { Ok(line) if line.starts_with('\\') => { - rl.add_history_entry(line.trim_end()); + rl.add_history_entry(line.trim_end()).unwrap(); let command = line.split_whitespace().collect::>().join(" "); if let Ok(cmd) = &command[1..].parse::() { match cmd { @@ -133,7 +133,7 @@ pub async fn exec_from_repl(ctx: &BallistaContext, print_options: &mut PrintOpti } } Ok(line) => { - rl.add_history_entry(line.trim_end()); + rl.add_history_entry(line.trim_end()).unwrap(); match exec_and_print(ctx, &print_options, line).await { Ok(_) => {} Err(err) => eprintln!("{err:?}"), diff --git a/ballista/client/src/context.rs b/ballista/client/src/context.rs index 94171a6da..f4885ccb0 100644 --- a/ballista/client/src/context.rs +++ b/ballista/client/src/context.rs @@ -613,7 +613,7 @@ mod tests { table_partition_cols: x.table_partition_cols.clone(), collect_stat: x.collect_stat, target_partitions: x.target_partitions, - file_sort_order: None, + file_sort_order: vec![], infinite_source: false, }; @@ -810,11 +810,11 @@ mod tests { .unwrap(); let res = df.collect().await.unwrap(); let expected = vec![ - "+-------------------------+", - "| APPROXDISTINCT(test.id) |", - "+-------------------------+", - "| 8 |", - "+-------------------------+", + "+--------------------------+", + "| APPROX_DISTINCT(test.id) |", + "+--------------------------+", + "| 8 |", + "+--------------------------+", ]; assert_result_eq(expected, &res); @@ -825,7 +825,7 @@ mod tests { let res = df.collect().await.unwrap(); let expected = vec![ "+--------------------------+", - "| ARRAYAGG(test.id) |", + "| ARRAY_AGG(test.id) |", "+--------------------------+", "| [4, 5, 6, 7, 2, 3, 0, 1] |", "+--------------------------+", @@ -849,11 +849,11 @@ mod tests { .unwrap(); let res = df.collect().await.unwrap(); let expected = vec![ - "+----------------------+", - "| VARIANCEPOP(test.id) |", - "+----------------------+", - "| 5.250000000000001 |", - "+----------------------+", + "+-----------------------+", + "| VARIANCE_POP(test.id) |", + "+-----------------------+", + "| 5.250000000000001 |", + "+-----------------------+", ]; assert_result_eq(expected, &res); @@ -933,11 +933,11 @@ mod tests { .unwrap(); let res = df.collect().await.unwrap(); let expected = vec![ - "+---------------------------------------------------------------+", - "| APPROXPERCENTILECONTWITHWEIGHT(test.id,Int64(2),Float64(0.5)) |", - "+---------------------------------------------------------------+", - "| 1 |", - "+---------------------------------------------------------------+", + "+-------------------------------------------------------------------+", + "| APPROX_PERCENTILE_CONT_WITH_WEIGHT(test.id,Int64(2),Float64(0.5)) |", + "+-------------------------------------------------------------------+", + "| 1 |", + "+-------------------------------------------------------------------+", ]; assert_result_eq(expected, &res); @@ -947,11 +947,11 @@ mod tests { .unwrap(); let res = df.collect().await.unwrap(); let expected = vec![ - "+----------------------------------------------------+", - "| APPROXPERCENTILECONT(test.double_col,Float64(0.5)) |", - "+----------------------------------------------------+", - "| 7.574999999999999 |", - "+----------------------------------------------------+", + "+------------------------------------------------------+", + "| APPROX_PERCENTILE_CONT(test.double_col,Float64(0.5)) |", + "+------------------------------------------------------+", + "| 7.574999999999999 |", + "+------------------------------------------------------+", ]; assert_result_eq(expected, &res); diff --git a/ballista/scheduler/src/test_utils.rs b/ballista/scheduler/src/test_utils.rs index beaabca4c..067989039 100644 --- a/ballista/scheduler/src/test_utils.rs +++ b/ballista/scheduler/src/test_utils.rs @@ -151,51 +151,51 @@ pub fn get_tpch_schema(table: &str) -> Schema { match table { "part" => Schema::new(vec![ - Field::new("p_partkey", DataType::Int32, false), + Field::new("p_partkey", DataType::Int64, false), Field::new("p_name", DataType::Utf8, false), Field::new("p_mfgr", DataType::Utf8, false), Field::new("p_brand", DataType::Utf8, false), Field::new("p_type", DataType::Utf8, false), Field::new("p_size", DataType::Int32, false), Field::new("p_container", DataType::Utf8, false), - Field::new("p_retailprice", DataType::Float64, false), + Field::new("p_retailprice", DataType::Decimal128(15, 2), false), Field::new("p_comment", DataType::Utf8, false), ]), "supplier" => Schema::new(vec![ - Field::new("s_suppkey", DataType::Int32, false), + Field::new("s_suppkey", DataType::Int64, false), Field::new("s_name", DataType::Utf8, false), Field::new("s_address", DataType::Utf8, false), - Field::new("s_nationkey", DataType::Int32, false), + Field::new("s_nationkey", DataType::Int64, false), Field::new("s_phone", DataType::Utf8, false), - Field::new("s_acctbal", DataType::Float64, false), + Field::new("s_acctbal", DataType::Decimal128(15, 2), false), Field::new("s_comment", DataType::Utf8, false), ]), "partsupp" => Schema::new(vec![ - Field::new("ps_partkey", DataType::Int32, false), - Field::new("ps_suppkey", DataType::Int32, false), + Field::new("ps_partkey", DataType::Int64, false), + Field::new("ps_suppkey", DataType::Int64, false), Field::new("ps_availqty", DataType::Int32, false), - Field::new("ps_supplycost", DataType::Float64, false), + Field::new("ps_supplycost", DataType::Decimal128(15, 2), false), Field::new("ps_comment", DataType::Utf8, false), ]), "customer" => Schema::new(vec![ - Field::new("c_custkey", DataType::Int32, false), + Field::new("c_custkey", DataType::Int64, false), Field::new("c_name", DataType::Utf8, false), Field::new("c_address", DataType::Utf8, false), - Field::new("c_nationkey", DataType::Int32, false), + Field::new("c_nationkey", DataType::Int64, false), Field::new("c_phone", DataType::Utf8, false), - Field::new("c_acctbal", DataType::Float64, false), + Field::new("c_acctbal", DataType::Decimal128(15, 2), false), Field::new("c_mktsegment", DataType::Utf8, false), Field::new("c_comment", DataType::Utf8, false), ]), "orders" => Schema::new(vec![ - Field::new("o_orderkey", DataType::Int32, false), - Field::new("o_custkey", DataType::Int32, false), + Field::new("o_orderkey", DataType::Int64, false), + Field::new("o_custkey", DataType::Int64, false), Field::new("o_orderstatus", DataType::Utf8, false), - Field::new("o_totalprice", DataType::Float64, false), + Field::new("o_totalprice", DataType::Decimal128(15, 2), false), Field::new("o_orderdate", DataType::Date32, false), Field::new("o_orderpriority", DataType::Utf8, false), Field::new("o_clerk", DataType::Utf8, false), @@ -204,14 +204,14 @@ pub fn get_tpch_schema(table: &str) -> Schema { ]), "lineitem" => Schema::new(vec![ - Field::new("l_orderkey", DataType::Int32, false), - Field::new("l_partkey", DataType::Int32, false), - Field::new("l_suppkey", DataType::Int32, false), + Field::new("l_orderkey", DataType::Int64, false), + Field::new("l_partkey", DataType::Int64, false), + Field::new("l_suppkey", DataType::Int64, false), Field::new("l_linenumber", DataType::Int32, false), - Field::new("l_quantity", DataType::Float64, false), - Field::new("l_extendedprice", DataType::Float64, false), - Field::new("l_discount", DataType::Float64, false), - Field::new("l_tax", DataType::Float64, false), + Field::new("l_quantity", DataType::Decimal128(15, 2), false), + Field::new("l_extendedprice", DataType::Decimal128(15, 2), false), + Field::new("l_discount", DataType::Decimal128(15, 2), false), + Field::new("l_tax", DataType::Decimal128(15, 2), false), Field::new("l_returnflag", DataType::Utf8, false), Field::new("l_linestatus", DataType::Utf8, false), Field::new("l_shipdate", DataType::Date32, false), @@ -223,14 +223,14 @@ pub fn get_tpch_schema(table: &str) -> Schema { ]), "nation" => Schema::new(vec![ - Field::new("n_nationkey", DataType::Int32, false), + Field::new("n_nationkey", DataType::Int64, false), Field::new("n_name", DataType::Utf8, false), - Field::new("n_regionkey", DataType::Int32, false), + Field::new("n_regionkey", DataType::Int64, false), Field::new("n_comment", DataType::Utf8, false), ]), "region" => Schema::new(vec![ - Field::new("r_regionkey", DataType::Int32, false), + Field::new("r_regionkey", DataType::Int64, false), Field::new("r_name", DataType::Utf8, false), Field::new("r_comment", DataType::Utf8, false), ]), diff --git a/benchmarks/Cargo.toml b/benchmarks/Cargo.toml index 53ad3d136..90c99bb29 100644 --- a/benchmarks/Cargo.toml +++ b/benchmarks/Cargo.toml @@ -34,6 +34,7 @@ simd = ["datafusion/simd"] snmalloc = ["snmalloc-rs"] [dependencies] +arrow-schema = { workspace = true } ballista = { path = "../ballista/client", version = "0.11.0" } datafusion = { workspace = true } datafusion-proto = { workspace = true } diff --git a/benchmarks/src/bin/tpch.rs b/benchmarks/src/bin/tpch.rs index c0cfe0c58..1a58714f8 100644 --- a/benchmarks/src/bin/tpch.rs +++ b/benchmarks/src/bin/tpch.rs @@ -17,6 +17,7 @@ //! Benchmark derived from TPC-H. This is not an official TPC-H benchmark. +use arrow_schema::SchemaBuilder; use ballista::context::BallistaContext; use ballista::prelude::{ BallistaConfig, BALLISTA_COLLECT_STATISTICS, BALLISTA_DEFAULT_BATCH_SIZE, @@ -574,7 +575,7 @@ async fn register_tables( // dbgen creates .tbl ('|' delimited) files without header "tbl" => { let path = find_path(path, table, "tbl")?; - let schema = get_schema(table); + let schema = get_tbl_tpch_table_schema(table); let options = CsvReadOptions::new() .schema(&schema) .delimiter(b'|') @@ -843,7 +844,7 @@ async fn get_table( target_partitions, collect_stat: true, table_partition_cols: vec![], - file_sort_order: None, + file_sort_order: vec![], infinite_source: false, }; @@ -954,6 +955,13 @@ fn get_schema(table: &str) -> Schema { } } +/// The `.tbl` file contains a trailing column +pub fn get_tbl_tpch_table_schema(table: &str) -> Schema { + let mut schema = SchemaBuilder::from(get_schema(table).fields); + schema.push(Field::new("__placeholder", DataType::Utf8, false)); + schema.finish() +} + #[derive(Debug, Serialize)] struct BenchmarkRun { /// Benchmark crate version @@ -1030,13 +1038,26 @@ async fn get_expected_results(n: usize, path: &str) -> Result> .fields() .iter() .map(|field| { - Expr::Alias( - Box::new(Expr::Cast(Cast { - expr: Box::new(trim(col(Field::name(field)))), - data_type: Field::data_type(field).to_owned(), - })), - Field::name(field).to_string(), - ) + match Field::data_type(field) { + DataType::Decimal128(_, _) => { + // there's no support for casting from Utf8 to Decimal, so + // we'll cast from Utf8 to Float64 to Decimal for Decimal types + let inner_cast = Box::new(Expr::Cast(Cast::new( + Box::new(trim(col(Field::name(field)))), + DataType::Float64, + ))); + Expr::Cast(Cast::new( + inner_cast, + Field::data_type(field).to_owned(), + )) + .alias(Field::name(field)) + } + _ => Expr::Cast(Cast::new( + Box::new(trim(col(Field::name(field)))), + Field::data_type(field).to_owned(), + )) + .alias(Field::name(field)), + } }) .collect::>(), )?; @@ -1079,13 +1100,13 @@ fn get_answer_schema(n: usize) -> Schema { 1 => Schema::new(vec![ Field::new("l_returnflag", DataType::Utf8, true), Field::new("l_linestatus", DataType::Utf8, true), - Field::new("sum_qty", DataType::Float64, true), - Field::new("sum_base_price", DataType::Float64, true), - Field::new("sum_disc_price", DataType::Float64, true), - Field::new("sum_charge", DataType::Float64, true), - Field::new("avg_qty", DataType::Float64, true), - Field::new("avg_price", DataType::Decimal128(19, 6), false), //TODO should be precision 2 - Field::new("avg_disc", DataType::Float64, true), + Field::new("sum_qty", DataType::Decimal128(15, 2), true), + Field::new("sum_base_price", DataType::Decimal128(15, 2), true), + Field::new("sum_disc_price", DataType::Decimal128(15, 2), true), + Field::new("sum_charge", DataType::Decimal128(15, 2), true), + Field::new("avg_qty", DataType::Decimal128(15, 2), true), + Field::new("avg_price", DataType::Decimal128(15, 2), true), + Field::new("avg_disc", DataType::Decimal128(15, 2), true), Field::new("count_order", DataType::Int64, true), ]), @@ -1093,7 +1114,7 @@ fn get_answer_schema(n: usize) -> Schema { Field::new("s_acctbal", DataType::Decimal128(15, 2), true), Field::new("s_name", DataType::Utf8, true), Field::new("n_name", DataType::Utf8, true), - Field::new("p_partkey", DataType::Int32, true), + Field::new("p_partkey", DataType::Int64, true), Field::new("p_mfgr", DataType::Utf8, true), Field::new("s_address", DataType::Utf8, true), Field::new("s_phone", DataType::Utf8, true), @@ -1101,8 +1122,8 @@ fn get_answer_schema(n: usize) -> Schema { ]), 3 => Schema::new(vec![ - Field::new("l_orderkey", DataType::Int32, true), - Field::new("revenue", DataType::Decimal128(19, 6), true), //TODO should be precision 2 + Field::new("l_orderkey", DataType::Int64, true), + Field::new("revenue", DataType::Decimal128(15, 2), true), Field::new("o_orderdate", DataType::Date32, true), Field::new("o_shippriority", DataType::Int32, true), ]), @@ -1114,37 +1135,37 @@ fn get_answer_schema(n: usize) -> Schema { 5 => Schema::new(vec![ Field::new("n_name", DataType::Utf8, true), - Field::new("revenue", DataType::Decimal128(38, 4), true), //TODO should be precision 2 + Field::new("revenue", DataType::Decimal128(15, 2), true), ]), 6 => Schema::new(vec![Field::new( "revenue", - DataType::Decimal128(25, 2), + DataType::Decimal128(15, 2), true, )]), 7 => Schema::new(vec![ Field::new("supp_nation", DataType::Utf8, true), Field::new("cust_nation", DataType::Utf8, true), - Field::new("l_year", DataType::Int32, true), - Field::new("revenue", DataType::Decimal128(38, 4), true), //TODO should be precision 2 + Field::new("l_year", DataType::Float64, true), + Field::new("revenue", DataType::Decimal128(15, 2), true), ]), 8 => Schema::new(vec![ - Field::new("o_year", DataType::Int32, true), - Field::new("mkt_share", DataType::Decimal128(38, 4), true), //TODO should be precision 2 + Field::new("o_year", DataType::Float64, true), + Field::new("mkt_share", DataType::Decimal128(15, 2), true), ]), 9 => Schema::new(vec![ Field::new("nation", DataType::Utf8, true), - Field::new("o_year", DataType::Int32, true), - Field::new("sum_profit", DataType::Decimal128(38, 4), true), //TODO should be precision 2 + Field::new("o_year", DataType::Float64, true), + Field::new("sum_profit", DataType::Decimal128(15, 2), true), ]), 10 => Schema::new(vec![ - Field::new("c_custkey", DataType::Int32, true), + Field::new("c_custkey", DataType::Int64, true), Field::new("c_name", DataType::Utf8, true), - Field::new("revenue", DataType::Decimal128(38, 4), true), //TODO should be precision 2 + Field::new("revenue", DataType::Decimal128(15, 2), true), Field::new("c_acctbal", DataType::Decimal128(15, 2), true), Field::new("n_name", DataType::Utf8, true), Field::new("c_address", DataType::Utf8, true), @@ -1153,8 +1174,8 @@ fn get_answer_schema(n: usize) -> Schema { ]), 11 => Schema::new(vec![ - Field::new("ps_partkey", DataType::Int32, true), - Field::new("value", DataType::Decimal128(36, 2), true), + Field::new("ps_partkey", DataType::Int64, true), + Field::new("value", DataType::Decimal128(15, 2), true), ]), 12 => Schema::new(vec![ @@ -1168,22 +1189,24 @@ fn get_answer_schema(n: usize) -> Schema { Field::new("custdist", DataType::Int64, true), ]), - 14 => Schema::new(vec![ - Field::new("promo_revenue", DataType::Decimal128(38, 38), true), //TODO should be precision 2 - ]), + 14 => Schema::new(vec![Field::new("promo_revenue", DataType::Float64, true)]), - 15 => Schema::new(vec![Field::new("promo_revenue", DataType::Float64, true)]), + 15 => Schema::new(vec![ + Field::new("s_suppkey", DataType::Int64, true), + Field::new("s_name", DataType::Utf8, true), + Field::new("s_address", DataType::Utf8, true), + Field::new("s_phone", DataType::Utf8, true), + Field::new("total_revenue", DataType::Decimal128(15, 2), true), + ]), 16 => Schema::new(vec![ Field::new("p_brand", DataType::Utf8, true), Field::new("p_type", DataType::Utf8, true), - Field::new("c_phone", DataType::Int32, true), - Field::new("c_comment", DataType::Int32, true), + Field::new("p_size", DataType::Int32, true), + Field::new("supplier_cnt", DataType::Int64, true), ]), - 17 => Schema::new(vec![ - Field::new("avg_yearly", DataType::Decimal128(38, 3), true), //TODO should be precision 2 - ]), + 17 => Schema::new(vec![Field::new("avg_yearly", DataType::Float64, true)]), 18 => Schema::new(vec![ Field::new("c_name", DataType::Utf8, true), @@ -1191,12 +1214,14 @@ fn get_answer_schema(n: usize) -> Schema { Field::new("o_orderkey", DataType::Int64, true), Field::new("o_orderdate", DataType::Date32, true), Field::new("o_totalprice", DataType::Decimal128(15, 2), true), - Field::new("sum_l_quantity", DataType::Decimal128(25, 2), true), + Field::new("sum_l_quantity", DataType::Decimal128(15, 2), true), ]), - 19 => Schema::new(vec![ - Field::new("revenue", DataType::Decimal128(38, 4), true), //TODO should be precision 2 - ]), + 19 => Schema::new(vec![Field::new( + "revenue", + DataType::Decimal128(15, 2), + true, + )]), 20 => Schema::new(vec![ Field::new("s_name", DataType::Utf8, true), @@ -1211,7 +1236,7 @@ fn get_answer_schema(n: usize) -> Schema { 22 => Schema::new(vec![ Field::new("cntrycode", DataType::Utf8, true), Field::new("numcust", DataType::Int64, true), - Field::new("totacctbal", DataType::Decimal128(25, 2), true), + Field::new("totacctbal", DataType::Decimal128(15, 2), true), ]), _ => unimplemented!(), @@ -1452,7 +1477,7 @@ mod tests { run_query(14).await } - #[ignore] // https://github.com/apache/arrow-datafusion/issues/166 + #[ignore] // TODO: support multiline queries #[tokio::test] async fn run_q15() -> Result<()> { run_query(15).await