diff --git a/Cargo.toml b/Cargo.toml index fa0ce2461fafa..01aff810de17b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -125,6 +125,7 @@ datafusion-sql = { path = "datafusion/sql", version = "43.0.0" } datafusion-sqllogictest = { path = "datafusion/sqllogictest", version = "43.0.0" } datafusion-substrait = { path = "datafusion/substrait", version = "43.0.0" } doc-comment = "0.3" +enumset = "1.1.5" env_logger = "0.11" futures = "0.3" half = { version = "2.2.1", default-features = false } diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index fdab12579100a..546a5404348e1 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -406,9 +406,9 @@ dependencies = [ [[package]] name = "async-compression" -version = "0.4.17" +version = "0.4.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0cb8f1d480b0ea3783ab015936d2a55c87e219676f0c0b7dec61494043f21857" +checksum = "df895a515f70646414f4b45c0b79082783b80552b373a68283012928df56f522" dependencies = [ "bzip2", "flate2", @@ -1080,6 +1080,16 @@ dependencies = [ "libc", ] +[[package]] +name = "core-foundation" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b55271e5c8c478ad3f38ad24ef34923091e0548492a266d19b3c0b4d82574c63" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "core-foundation-sys" version = "0.8.7" @@ -1097,9 +1107,9 @@ dependencies = [ [[package]] name = "cpufeatures" -version = "0.2.15" +version = "0.2.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0ca741a962e1b0bff6d724a1a0958b686406e853bb14061f218562e1896f95e6" +checksum = "16b80225097f2e5ae4e7179dd2266824648f3e2f49d9134d584b76389d31c4c3" dependencies = [ "libc", ] @@ -1166,6 +1176,40 @@ dependencies = [ "syn", ] +[[package]] +name = "darling" +version = "0.20.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6f63b86c8a8826a49b8c21f08a2d07338eec8d900540f8630dc76284be802989" +dependencies = [ + "darling_core", + "darling_macro", +] + +[[package]] +name = "darling_core" +version = "0.20.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "95133861a8032aaea082871032f5815eb9e98cef03fa916ab4500513994df9e5" +dependencies = [ + "fnv", + "ident_case", + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "darling_macro" +version = "0.20.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d336a2a514f6ccccaa3e09b02d41d35330c07ddf03a62165fcec10bb561c7806" +dependencies = [ + "darling_core", + "quote", + "syn", +] + [[package]] name = "dary_heap" version = "0.3.7" @@ -1357,6 +1401,7 @@ dependencies = [ "datafusion-functions-aggregate-common", "datafusion-functions-window-common", "datafusion-physical-expr-common", + "enumset", "indexmap", "paste", "recursive", @@ -1496,6 +1541,7 @@ dependencies = [ "datafusion-common", "datafusion-expr", "datafusion-physical-expr", + "enumset", "hashbrown 0.14.5", "indexmap", "itertools", @@ -1688,6 +1734,27 @@ version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c34f04666d835ff5d62e058c3995147c06f42fe86ff053337632bca83e42702d" +[[package]] +name = "enumset" +version = "1.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d07a4b049558765cef5f0c1a273c3fc57084d768b44d2f98127aef4cceb17293" +dependencies = [ + "enumset_derive", +] + +[[package]] +name = "enumset_derive" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59c3b24c345d8c314966bdc1832f6c2635bfcce8e7cf363bd115987bba2ee242" +dependencies = [ + "darling", + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "env_filter" version = "0.1.2" @@ -2181,8 +2248,8 @@ dependencies = [ "http 1.1.0", "hyper 1.5.1", "hyper-util", - "rustls 0.23.17", - "rustls-native-certs 0.8.0", + "rustls 0.23.18", + "rustls-native-certs 0.8.1", "rustls-pki-types", "tokio", "tokio-rustls 0.26.0", @@ -2349,6 +2416,12 @@ dependencies = [ "syn", ] +[[package]] +name = "ident_case" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39" + [[package]] name = "idna" version = "1.0.3" @@ -2565,9 +2638,9 @@ checksum = "78b3ae25bc7c8c38cec158d1f2757ee79e9b3740fbc7ccf0e59e4b08d793fa89" [[package]] name = "litemap" -version = "0.7.3" +version = "0.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "643cb0b8d4fcc284004d5fd0d67ccf61dfffadb7f75e1e71bc420f4688a3a704" +checksum = "4ee93343901ab17bd981295f2cf0026d4ad018c7c31ba84549a4ddbb47a45104" [[package]] name = "lock_api" @@ -3049,9 +3122,9 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.89" +version = "1.0.92" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f139b0662de085916d1fb67d2b4169d1addddda1919e696f3252b740b629986e" +checksum = "37d3544b3f2748c54e147655edb5025752e2303145b5aefb3c3ea2c78b973bb0" dependencies = [ "unicode-ident", ] @@ -3092,7 +3165,7 @@ dependencies = [ "quinn-proto", "quinn-udp", "rustc-hash", - "rustls 0.23.17", + "rustls 0.23.18", "socket2", "thiserror 2.0.3", "tokio", @@ -3110,7 +3183,7 @@ dependencies = [ "rand", "ring", "rustc-hash", - "rustls 0.23.17", + "rustls 0.23.18", "rustls-pki-types", "slab", "thiserror 2.0.3", @@ -3288,8 +3361,8 @@ dependencies = [ "percent-encoding", "pin-project-lite", "quinn", - "rustls 0.23.17", - "rustls-native-certs 0.8.0", + "rustls 0.23.18", + "rustls-native-certs 0.8.1", "rustls-pemfile 2.2.0", "rustls-pki-types", "serde", @@ -3407,9 +3480,9 @@ dependencies = [ [[package]] name = "rustls" -version = "0.23.17" +version = "0.23.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7f1a745511c54ba6d4465e8d5dfbd81b45791756de28d4981af70d6dca128f1e" +checksum = "9c9cc1d47e243d655ace55ed38201c19ae02c148ae56412ab8750e8f0166ab7f" dependencies = [ "once_cell", "ring", @@ -3428,20 +3501,19 @@ dependencies = [ "openssl-probe", "rustls-pemfile 1.0.4", "schannel", - "security-framework", + "security-framework 2.11.1", ] [[package]] name = "rustls-native-certs" -version = "0.8.0" +version = "0.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fcaf18a4f2be7326cd874a5fa579fae794320a0f388d365dca7e480e55f83f8a" +checksum = "7fcff2dd52b58a8d98a70243663a0d234c4e2b79235637849d15913394a247d3" dependencies = [ "openssl-probe", - "rustls-pemfile 2.2.0", "rustls-pki-types", "schannel", - "security-framework", + "security-framework 3.0.1", ] [[package]] @@ -3567,7 +3639,20 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "897b2245f0b511c87893af39b033e5ca9cce68824c4d7e7630b5a1d339658d02" dependencies = [ "bitflags 2.6.0", - "core-foundation", + "core-foundation 0.9.4", + "core-foundation-sys", + "libc", + "security-framework-sys", +] + +[[package]] +name = "security-framework" +version = "3.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e1415a607e92bec364ea2cf9264646dcce0f91e6d65281bd6f2819cca3bf39c8" +dependencies = [ + "bitflags 2.6.0", + "core-foundation 0.10.0", "core-foundation-sys", "libc", "security-framework-sys", @@ -3830,9 +3915,9 @@ checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "syn" -version = "2.0.87" +version = "2.0.89" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "25aa4ce346d03a6dcd68dd8b4010bcb74e54e62c90c573f394c46eae99aba32d" +checksum = "44d46482f1c1c87acd84dea20c1bf5ebff4c757009ed6bf19cfd36fb10e92c4e" dependencies = [ "proc-macro2", "quote", @@ -4038,7 +4123,7 @@ version = "0.26.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0c7bc40d0e5a97695bb96e27995cd3a08538541b0a846f65bba7a359f36700d4" dependencies = [ - "rustls 0.23.17", + "rustls 0.23.18", "rustls-pki-types", "tokio", ] @@ -4184,9 +4269,9 @@ checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" [[package]] name = "url" -version = "2.5.3" +version = "2.5.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8d157f1b96d14500ffdc1f10ba712e780825526c03d9a49b4d0324b0d9113ada" +checksum = "32f8b686cadd1473f4bd0117a5d28d36b1ade384ea9b5069a1c40aefed7fda60" dependencies = [ "form_urlencoded", "idna", @@ -4607,9 +4692,9 @@ dependencies = [ [[package]] name = "yoke" -version = "0.7.4" +version = "0.7.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c5b1314b079b0930c31e3af543d8ee1757b1951ae1e1565ec704403a7240ca5" +checksum = "120e6aef9aa629e3d4f52dc8cc43a015c7724194c97dfaf45180d2daf2b77f40" dependencies = [ "serde", "stable_deref_trait", @@ -4619,9 +4704,9 @@ dependencies = [ [[package]] name = "yoke-derive" -version = "0.7.4" +version = "0.7.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "28cc31741b18cb6f1d5ff12f5b7523e3d6eb0852bbbad19d73905511d9849b95" +checksum = "2380878cad4ac9aac1e2435f3eb4020e8374b5f13c296cb75b4620ff8e229154" dependencies = [ "proc-macro2", "quote", @@ -4652,18 +4737,18 @@ dependencies = [ [[package]] name = "zerofrom" -version = "0.1.4" +version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "91ec111ce797d0e0784a1116d0ddcdbea84322cd79e5d5ad173daeba4f93ab55" +checksum = "cff3ee08c995dee1859d998dea82f7374f2826091dd9cd47def953cae446cd2e" dependencies = [ "zerofrom-derive", ] [[package]] name = "zerofrom-derive" -version = "0.1.4" +version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0ea7b4a3637ea8669cedf0f1fd5c286a17f3de97b8dd5a70a6c167a1730e63a5" +checksum = "595eed982f7d355beb85837f651fa22e90b3c044842dc7f2c2842c086f295808" dependencies = [ "proc-macro2", "quote", diff --git a/datafusion-cli/src/exec.rs b/datafusion-cli/src/exec.rs index 18906536691ef..78aebdad933dd 100644 --- a/datafusion-cli/src/exec.rs +++ b/datafusion-cli/src/exec.rs @@ -274,9 +274,9 @@ impl AdjustedPrintOptions { // all rows if matches!( plan, - LogicalPlan::Explain(_) - | LogicalPlan::DescribeTable(_) - | LogicalPlan::Analyze(_) + LogicalPlan::Explain(_, _) + | LogicalPlan::DescribeTable(_, _) + | LogicalPlan::Analyze(_, _) ) { self.inner.maxrows = MaxRows::Unlimited; } @@ -311,7 +311,7 @@ async fn create_plan( // Note that cmd is a mutable reference so that create_external_table function can remove all // datafusion-cli specific options before passing through to datafusion. Otherwise, datafusion // will raise Configuration errors. - if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &plan { + if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd), _) = &plan { // To support custom formats, treat error as None let format = config_file_type_from_str(&cmd.file_type); register_object_store_and_config_extensions( @@ -323,7 +323,7 @@ async fn create_plan( .await?; } - if let LogicalPlan::Copy(copy_to) = &mut plan { + if let LogicalPlan::Copy(copy_to, _) = &mut plan { let format = config_file_type_from_str(©_to.file_type.get_ext()); register_object_store_and_config_extensions( @@ -412,7 +412,7 @@ mod tests { let ctx = SessionContext::new(); let plan = ctx.state().create_logical_plan(sql).await?; - if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &plan { + if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd), _) = &plan { let format = config_file_type_from_str(&cmd.file_type); register_object_store_and_config_extensions( &ctx, @@ -438,7 +438,7 @@ mod tests { let plan = ctx.state().create_logical_plan(sql).await?; - if let LogicalPlan::Copy(cmd) = &plan { + if let LogicalPlan::Copy(cmd, _) = &plan { let format = config_file_type_from_str(&cmd.file_type.get_ext()); register_object_store_and_config_extensions( &ctx, @@ -492,7 +492,7 @@ mod tests { for statement in statements { //Should not fail let mut plan = create_plan(&ctx, statement).await?; - if let LogicalPlan::Copy(copy_to) = &mut plan { + if let LogicalPlan::Copy(copy_to, _) = &mut plan { assert_eq!(copy_to.output_url, location); assert_eq!(copy_to.file_type.get_ext(), "parquet".to_string()); ctx.runtime_env() diff --git a/datafusion-cli/src/functions.rs b/datafusion-cli/src/functions.rs index c622463de0331..7288aca0ac8f0 100644 --- a/datafusion-cli/src/functions.rs +++ b/datafusion-cli/src/functions.rs @@ -321,8 +321,8 @@ pub struct ParquetMetadataFunc {} impl TableFunctionImpl for ParquetMetadataFunc { fn call(&self, exprs: &[Expr]) -> Result> { let filename = match exprs.first() { - Some(Expr::Literal(ScalarValue::Utf8(Some(s)))) => s, // single quote: parquet_metadata('x.parquet') - Some(Expr::Column(Column { name, .. })) => name, // double quote: parquet_metadata("x.parquet") + Some(Expr::Literal(ScalarValue::Utf8(Some(s)), _)) => s, // single quote: parquet_metadata('x.parquet') + Some(Expr::Column(Column { name, .. }, _)) => name, // double quote: parquet_metadata("x.parquet") _ => { return plan_err!( "parquet_metadata requires string argument as its input" diff --git a/datafusion-cli/src/object_storage.rs b/datafusion-cli/src/object_storage.rs index 3d999766e03fb..a452d2540588f 100644 --- a/datafusion-cli/src/object_storage.rs +++ b/datafusion-cli/src/object_storage.rs @@ -493,7 +493,7 @@ mod tests { let ctx = SessionContext::new(); let mut plan = ctx.state().create_logical_plan(&sql).await?; - if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &mut plan { + if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd), _) = &mut plan { ctx.register_table_options_extension_from_scheme(scheme); let mut table_options = ctx.state().default_table_options(); table_options.alter_with_string_hash_map(&cmd.options)?; @@ -538,7 +538,7 @@ mod tests { let ctx = SessionContext::new(); let mut plan = ctx.state().create_logical_plan(&sql).await?; - if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &mut plan { + if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd), _) = &mut plan { ctx.register_table_options_extension_from_scheme(scheme); let mut table_options = ctx.state().default_table_options(); table_options.alter_with_string_hash_map(&cmd.options)?; @@ -564,7 +564,7 @@ mod tests { let mut plan = ctx.state().create_logical_plan(&sql).await?; - if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &mut plan { + if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd), _) = &mut plan { ctx.register_table_options_extension_from_scheme(scheme); let mut table_options = ctx.state().default_table_options(); table_options.alter_with_string_hash_map(&cmd.options)?; @@ -592,7 +592,7 @@ mod tests { let ctx = SessionContext::new(); let mut plan = ctx.state().create_logical_plan(&sql).await?; - if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &mut plan { + if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd), _) = &mut plan { ctx.register_table_options_extension_from_scheme(scheme); let mut table_options = ctx.state().default_table_options(); table_options.alter_with_string_hash_map(&cmd.options)?; @@ -629,7 +629,7 @@ mod tests { let ctx = SessionContext::new(); let mut plan = ctx.state().create_logical_plan(&sql).await?; - if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &mut plan { + if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd), _) = &mut plan { ctx.register_table_options_extension_from_scheme(scheme); let mut table_options = ctx.state().default_table_options(); table_options.alter_with_string_hash_map(&cmd.options)?; diff --git a/datafusion-examples/examples/analyzer_rule.rs b/datafusion-examples/examples/analyzer_rule.rs index bd067be97b8b3..15695992521c0 100644 --- a/datafusion-examples/examples/analyzer_rule.rs +++ b/datafusion-examples/examples/analyzer_rule.rs @@ -175,7 +175,7 @@ impl AnalyzerRule for RowLevelAccessControl { } fn is_employee_table_scan(plan: &LogicalPlan) -> bool { - if let LogicalPlan::TableScan(scan) = plan { + if let LogicalPlan::TableScan(scan, _) = plan { scan.table_name.table() == "employee" } else { false diff --git a/datafusion-examples/examples/expr_api.rs b/datafusion-examples/examples/expr_api.rs index cb0796bdcf735..2cce1c4a7a491 100644 --- a/datafusion-examples/examples/expr_api.rs +++ b/datafusion-examples/examples/expr_api.rs @@ -61,10 +61,10 @@ async fn main() -> Result<()> { let expr = col("a") + lit(5); // The same same expression can be created directly, with much more code: - let expr2 = Expr::BinaryExpr(BinaryExpr::new( + let expr2 = Expr::binary_expr(BinaryExpr::new( Box::new(col("a")), Operator::Plus, - Box::new(Expr::Literal(ScalarValue::Int32(Some(5)))), + Box::new(Expr::literal(ScalarValue::Int32(Some(5)))), )); assert_eq!(expr, expr2); @@ -396,20 +396,20 @@ fn type_coercion_demo() -> Result<()> { let coerced_expr = expr .transform(|e| { // Only type coerces binary expressions. - let Expr::BinaryExpr(e) = e else { + let Expr::BinaryExpr(e, _) = e else { return Ok(Transformed::no(e)); }; - if let Expr::Column(ref col_expr) = *e.left { + if let Expr::Column(ref col_expr, _) = *e.left { let field = df_schema.field_with_name(None, col_expr.name())?; let cast_to_type = field.data_type(); let coerced_right = e.right.cast_to(cast_to_type, &df_schema)?; - Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr::new( + Ok(Transformed::yes(Expr::binary_expr(BinaryExpr::new( e.left, e.op, Box::new(coerced_right), )))) } else { - Ok(Transformed::no(Expr::BinaryExpr(e))) + Ok(Transformed::no(Expr::binary_expr(e))) } })? .data; diff --git a/datafusion-examples/examples/function_factory.rs b/datafusion-examples/examples/function_factory.rs index b42f25437d772..3fa784b6f7677 100644 --- a/datafusion-examples/examples/function_factory.rs +++ b/datafusion-examples/examples/function_factory.rs @@ -167,7 +167,7 @@ impl ScalarFunctionWrapper { fn replacement(expr: &Expr, args: &[Expr]) -> Result { let result = expr.clone().transform(|e| { let r = match e { - Expr::Placeholder(placeholder) => { + Expr::Placeholder(placeholder, _) => { let placeholder_position = Self::parse_placeholder_identifier(&placeholder.id)?; if placeholder_position < args.len() { diff --git a/datafusion-examples/examples/optimizer_rule.rs b/datafusion-examples/examples/optimizer_rule.rs index e0b552620a9af..b1716650966bb 100644 --- a/datafusion-examples/examples/optimizer_rule.rs +++ b/datafusion-examples/examples/optimizer_rule.rs @@ -145,7 +145,7 @@ impl MyOptimizerRule { expr.transform_up(|expr| { // Closure called for each sub tree match expr { - Expr::BinaryExpr(binary_expr) if is_binary_eq(&binary_expr) => { + Expr::BinaryExpr(binary_expr, _) if is_binary_eq(&binary_expr) => { // destruture the expression let BinaryExpr { left, op: _, right } = binary_expr; // rewrite to `my_eq(left, right)` @@ -171,7 +171,7 @@ fn is_binary_eq(binary_expr: &BinaryExpr) -> bool { /// Return true if the expression is a literal or column reference fn is_lit_or_col(expr: &Expr) -> bool { - matches!(expr, Expr::Column(_) | Expr::Literal(_)) + matches!(expr, Expr::Column(_, _) | Expr::Literal(_, _)) } /// A simple user defined filter function diff --git a/datafusion-examples/examples/simple_udtf.rs b/datafusion-examples/examples/simple_udtf.rs index 6faa397ef60f3..3d3300e57c196 100644 --- a/datafusion-examples/examples/simple_udtf.rs +++ b/datafusion-examples/examples/simple_udtf.rs @@ -133,7 +133,8 @@ struct LocalCsvTableFunc {} impl TableFunctionImpl for LocalCsvTableFunc { fn call(&self, exprs: &[Expr]) -> Result> { - let Some(Expr::Literal(ScalarValue::Utf8(Some(ref path)))) = exprs.first() else { + let Some(Expr::Literal(ScalarValue::Utf8(Some(ref path)), _)) = exprs.first() + else { return plan_err!("read_csv requires at least one string argument"); }; @@ -145,7 +146,7 @@ impl TableFunctionImpl for LocalCsvTableFunc { let info = SimplifyContext::new(&execution_props); let expr = ExprSimplifier::new(info).simplify(expr.clone())?; - if let Expr::Literal(ScalarValue::Int64(Some(limit))) = expr { + if let Expr::Literal(ScalarValue::Int64(Some(limit)), _) = expr { Ok(limit as usize) } else { plan_err!("Limit must be an integer") diff --git a/datafusion-examples/examples/simplify_udaf_expression.rs b/datafusion-examples/examples/simplify_udaf_expression.rs index 52a27317e3c3d..707d2ed1d8ff0 100644 --- a/datafusion-examples/examples/simplify_udaf_expression.rs +++ b/datafusion-examples/examples/simplify_udaf_expression.rs @@ -91,7 +91,7 @@ impl AggregateUDFImpl for BetterAvgUdaf { // as an example for this functionality we replace UDF function // with build-in aggregate function to illustrate the use let simplify = |aggregate_function: AggregateFunction, _: &dyn SimplifyInfo| { - Ok(Expr::AggregateFunction(AggregateFunction::new_udf( + Ok(Expr::aggregate_function(AggregateFunction::new_udf( avg_udaf(), // yes it is the same Avg, `BetterAvgUdaf` was just a // marketing pitch :) diff --git a/datafusion-examples/examples/simplify_udwf_expression.rs b/datafusion-examples/examples/simplify_udwf_expression.rs index 117063df4e0d8..410fe73d28fef 100644 --- a/datafusion-examples/examples/simplify_udwf_expression.rs +++ b/datafusion-examples/examples/simplify_udwf_expression.rs @@ -71,7 +71,7 @@ impl WindowUDFImpl for SimplifySmoothItUdf { /// this function will simplify `SimplifySmoothItUdf` to `SmoothItUdf`. fn simplify(&self) -> Option { let simplify = |window_function: WindowFunction, _: &dyn SimplifyInfo| { - Ok(Expr::WindowFunction(WindowFunction { + Ok(Expr::window_function(WindowFunction { fun: datafusion_expr::WindowFunctionDefinition::AggregateUDF(avg_udaf()), args: window_function.args, partition_by: window_function.partition_by, diff --git a/datafusion-examples/examples/sql_analysis.rs b/datafusion-examples/examples/sql_analysis.rs index 2158b8e4b016e..db118eb26cf36 100644 --- a/datafusion-examples/examples/sql_analysis.rs +++ b/datafusion-examples/examples/sql_analysis.rs @@ -39,7 +39,7 @@ fn total_join_count(plan: &LogicalPlan) -> usize { // We can use the TreeNode API to walk over a LogicalPlan. plan.apply(|node| { // if we encounter a join we update the running count - if matches!(node, LogicalPlan::Join(_)) { + if matches!(node, LogicalPlan::Join(_, _)) { total += 1; } Ok(TreeNodeRecursion::Continue) @@ -89,7 +89,7 @@ fn count_trees(plan: &LogicalPlan) -> (usize, Vec) { while let Some(node) = to_visit.pop() { // if we encounter a join, we know were at the root of the tree // count this tree and recurse on it's inputs - if matches!(node, LogicalPlan::Join(_)) { + if matches!(node, LogicalPlan::Join(_, _)) { let (group_count, inputs) = count_tree(node); total += group_count; groups.push(group_count); @@ -146,12 +146,12 @@ fn count_tree(join: &LogicalPlan) -> (usize, Vec<&LogicalPlan>) { // / \ // B C // we can continue the recursion in this case - if let LogicalPlan::Projection(_) = node { + if let LogicalPlan::Projection(_, _) = node { return Ok(TreeNodeRecursion::Continue); } // any join we count - if matches!(node, LogicalPlan::Join(_)) { + if matches!(node, LogicalPlan::Join(_, _)) { total += 1; Ok(TreeNodeRecursion::Continue) } else { diff --git a/datafusion/catalog/src/table.rs b/datafusion/catalog/src/table.rs index d771930de25de..353abdbf97110 100644 --- a/datafusion/catalog/src/table.rs +++ b/datafusion/catalog/src/table.rs @@ -213,7 +213,7 @@ pub trait TableProvider: Debug + Sync + Send { /// let support: Vec<_> = filters.iter().map(|expr| { /// match expr { /// // This example only supports a between expr with a single column named "c1". - /// Expr::Between(between_expr) => { + /// Expr::Between(between_expr, _) => { /// between_expr.expr /// .try_as_col() /// .map(|column| { diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs index 0c153583e34b1..d9800c9e24af8 100644 --- a/datafusion/common/src/tree_node.rs +++ b/datafusion/common/src/tree_node.rs @@ -686,6 +686,11 @@ impl Transformed { Self::new(data, false, TreeNodeRecursion::Continue) } + /// Wrapper for unchanged data with [`TreeNodeRecursion::Jump`] statement. + pub fn jump(data: T) -> Self { + Self::new(data, false, TreeNodeRecursion::Jump) + } + /// Applies an infallible `f` to the data of this [`Transformed`] object, /// without modifying the `transformed` flag. pub fn update_data U>(self, f: F) -> Transformed { diff --git a/datafusion/core/benches/map_query_sql.rs b/datafusion/core/benches/map_query_sql.rs index e4c5f7c5deb3b..52ee21cc1a946 100644 --- a/datafusion/core/benches/map_query_sql.rs +++ b/datafusion/core/benches/map_query_sql.rs @@ -71,8 +71,8 @@ fn criterion_benchmark(c: &mut Criterion) { let mut value_buffer = Vec::new(); for i in 0..1000 { - key_buffer.push(Expr::Literal(ScalarValue::Utf8(Some(keys[i].clone())))); - value_buffer.push(Expr::Literal(ScalarValue::Int32(Some(values[i])))); + key_buffer.push(Expr::literal(ScalarValue::Utf8(Some(keys[i].clone())))); + value_buffer.push(Expr::literal(ScalarValue::Int32(Some(values[i])))); } c.bench_function("map_1000_1", |b| { b.iter(|| { diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index 82ee52d7b2e39..ef8bed391503f 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -253,7 +253,7 @@ impl DataFrame { .collect::>>()?; let expr: Vec = fields .into_iter() - .map(|(qualifier, field)| Expr::Column(Column::from((qualifier, field)))) + .map(|(qualifier, field)| Expr::column(Column::from((qualifier, field)))) .collect(); self.select(expr) } @@ -369,7 +369,7 @@ impl DataFrame { .enumerate() .map(|(idx, _)| self.plan.schema().qualified_field(idx)) .filter(|(qualifier, f)| !fields_to_drop.contains(&(*qualifier, f))) - .map(|(qualifier, field)| Expr::Column(Column::from((qualifier, field)))) + .map(|(qualifier, field)| Expr::column(Column::from((qualifier, field)))) .collect(); self.select(expr) } @@ -513,7 +513,7 @@ impl DataFrame { group_expr: Vec, aggr_expr: Vec, ) -> Result { - let is_grouping_set = matches!(group_expr.as_slice(), [Expr::GroupingSet(_)]); + let is_grouping_set = matches!(group_expr.as_slice(), [Expr::GroupingSet(_, _)]); let aggr_expr_len = aggr_expr.len(); let plan = LogicalPlanBuilder::from(self.plan) .aggregate(group_expr, aggr_expr)? @@ -527,7 +527,7 @@ impl DataFrame { .into_iter() .enumerate() .filter(|(idx, _)| *idx != grouping_id_pos) - .map(|(_, column)| Expr::Column(column)) + .map(|(_, column)| Expr::column(column)) .collect::>(); LogicalPlanBuilder::from(plan).project(exprs)?.build()? } else { @@ -1164,7 +1164,7 @@ impl DataFrame { /// ``` pub async fn count(self) -> Result { let rows = self - .aggregate(vec![], vec![count(Expr::Literal(COUNT_STAR_EXPANSION))])? + .aggregate(vec![], vec![count(Expr::literal(COUNT_STAR_EXPANSION))])? .collect() .await?; let len = *rows @@ -1403,7 +1403,7 @@ impl DataFrame { /// # } /// ``` pub fn explain(self, verbose: bool, analyze: bool) -> Result { - if matches!(self.plan, LogicalPlan::Explain(_)) { + if matches!(self.plan, LogicalPlan::Explain(_, _)) { return plan_err!("Nested EXPLAINs are not supported"); } let plan = LogicalPlanBuilder::from(self.plan) @@ -2175,7 +2175,7 @@ mod tests { async fn select_with_window_exprs() -> Result<()> { // build plan using Table API let t = test_table().await?; - let first_row = Expr::WindowFunction(WindowFunction::new( + let first_row = Expr::window_function(WindowFunction::new( WindowFunctionDefinition::WindowUDF(first_value_udwf()), vec![col("aggregate_test_100.c1")], )) @@ -2741,7 +2741,7 @@ mod tests { vec![col("c3")], ); - Expr::WindowFunction(w) + Expr::window_function(w) .null_treatment(NullTreatment::IgnoreNulls) .order_by(vec![col("c2").sort(true, true), col("c3").sort(true, true)]) .window_frame(WindowFrame::new_bounds( @@ -3007,7 +3007,7 @@ mod tests { let join = left.clone().join_on( right.clone(), JoinType::Inner, - Some(Expr::Literal(ScalarValue::Null)), + Some(Expr::literal(ScalarValue::Null)), )?; let expected_plan = "EmptyRelation"; assert_eq!(expected_plan, format!("{}", join.into_optimized_plan()?)); diff --git a/datafusion/core/src/datasource/listing/helpers.rs b/datafusion/core/src/datasource/listing/helpers.rs index 04c64156b125b..ddf5934d4b151 100644 --- a/datafusion/core/src/datasource/listing/helpers.rs +++ b/datafusion/core/src/datasource/listing/helpers.rs @@ -54,7 +54,7 @@ use object_store::{ObjectMeta, ObjectStore}; pub fn expr_applicable_for_cols(col_names: &[&str], expr: &Expr) -> bool { let mut is_applicable = true; expr.apply(|expr| match expr { - Expr::Column(Column { ref name, .. }) => { + Expr::Column(Column { ref name, .. }, _) => { is_applicable &= col_names.contains(&name.as_str()); if is_applicable { Ok(TreeNodeRecursion::Jump) @@ -62,34 +62,34 @@ pub fn expr_applicable_for_cols(col_names: &[&str], expr: &Expr) -> bool { Ok(TreeNodeRecursion::Stop) } } - Expr::Literal(_) - | Expr::Alias(_) - | Expr::OuterReferenceColumn(_, _) - | Expr::ScalarVariable(_, _) - | Expr::Not(_) - | Expr::IsNotNull(_) - | Expr::IsNull(_) - | Expr::IsTrue(_) - | Expr::IsFalse(_) - | Expr::IsUnknown(_) - | Expr::IsNotTrue(_) - | Expr::IsNotFalse(_) - | Expr::IsNotUnknown(_) - | Expr::Negative(_) - | Expr::Cast(_) - | Expr::TryCast(_) - | Expr::BinaryExpr(_) - | Expr::Between(_) - | Expr::Like(_) - | Expr::SimilarTo(_) - | Expr::InList(_) - | Expr::Exists(_) - | Expr::InSubquery(_) - | Expr::ScalarSubquery(_) - | Expr::GroupingSet(_) - | Expr::Case(_) => Ok(TreeNodeRecursion::Continue), - - Expr::ScalarFunction(scalar_function) => { + Expr::Literal(_, _) + | Expr::Alias(_, _) + | Expr::OuterReferenceColumn(_, _, _) + | Expr::ScalarVariable(_, _, _) + | Expr::Not(_, _) + | Expr::IsNotNull(_, _) + | Expr::IsNull(_, _) + | Expr::IsTrue(_, _) + | Expr::IsFalse(_, _) + | Expr::IsUnknown(_, _) + | Expr::IsNotTrue(_, _) + | Expr::IsNotFalse(_, _) + | Expr::IsNotUnknown(_, _) + | Expr::Negative(_, _) + | Expr::Cast(_, _) + | Expr::TryCast(_, _) + | Expr::BinaryExpr(_, _) + | Expr::Between(_, _) + | Expr::Like(_, _) + | Expr::SimilarTo(_, _) + | Expr::InList(_, _) + | Expr::Exists(_, _) + | Expr::InSubquery(_, _) + | Expr::ScalarSubquery(_, _) + | Expr::GroupingSet(_, _) + | Expr::Case(_, _) => Ok(TreeNodeRecursion::Continue), + + Expr::ScalarFunction(scalar_function, _) => { match scalar_function.func.signature().volatility { Volatility::Immutable => Ok(TreeNodeRecursion::Continue), // TODO: Stable functions could be `applicable`, but that would require access to the context @@ -108,7 +108,7 @@ pub fn expr_applicable_for_cols(col_names: &[&str], expr: &Expr) -> bool { | Expr::WindowFunction { .. } | Expr::Wildcard { .. } | Expr::Unnest { .. } - | Expr::Placeholder(_) => { + | Expr::Placeholder(_, _) => { is_applicable = false; Ok(TreeNodeRecursion::Stop) } @@ -330,16 +330,19 @@ fn populate_partition_values<'a>( partition_values: &mut HashMap<&'a str, PartitionValue>, filter: &'a Expr, ) { - if let Expr::BinaryExpr(BinaryExpr { - ref left, - op, - ref right, - }) = filter + if let Expr::BinaryExpr( + BinaryExpr { + ref left, + op, + ref right, + }, + _, + ) = filter { match op { Operator::Eq => match (left.as_ref(), right.as_ref()) { - (Expr::Column(Column { ref name, .. }), Expr::Literal(val)) - | (Expr::Literal(val), Expr::Column(Column { ref name, .. })) => { + (Expr::Column(Column { ref name, .. }, _), Expr::Literal(val, _)) + | (Expr::Literal(val, _), Expr::Column(Column { ref name, .. }, _)) => { if partition_values .insert(name, PartitionValue::Single(val.to_string())) .is_some() @@ -868,7 +871,7 @@ mod tests { assert_eq!( evaluate_partition_prefix( partitions, - &[col("a").eq(Expr::Literal(ScalarValue::Date32(Some(3))))], + &[col("a").eq(Expr::literal(ScalarValue::Date32(Some(3))))], ), Some(Path::from("a=1970-01-04")), ); @@ -877,7 +880,7 @@ mod tests { assert_eq!( evaluate_partition_prefix( partitions, - &[col("a").eq(Expr::Literal(ScalarValue::Date64(Some( + &[col("a").eq(Expr::literal(ScalarValue::Date64(Some( 4 * 24 * 60 * 60 * 1000 )))),], ), diff --git a/datafusion/core/src/datasource/listing/table.rs b/datafusion/core/src/datasource/listing/table.rs index ffe49dd2ba116..c35cdba9aaefd 100644 --- a/datafusion/core/src/datasource/listing/table.rs +++ b/datafusion/core/src/datasource/listing/table.rs @@ -1954,10 +1954,10 @@ mod tests { false, )])); - let filter_predicate = Expr::BinaryExpr(BinaryExpr::new( - Box::new(Expr::Column("column1".into())), + let filter_predicate = Expr::binary_expr(BinaryExpr::new( + Box::new(Expr::column("column1".into())), Operator::GtEq, - Box::new(Expr::Literal(ScalarValue::Int32(Some(0)))), + Box::new(Expr::literal(ScalarValue::Int32(Some(0)))), )); // Create a new batch of data to insert into the table diff --git a/datafusion/core/src/datasource/mod.rs b/datafusion/core/src/datasource/mod.rs index ad369b75e1301..ddb7f83e26390 100644 --- a/datafusion/core/src/datasource/mod.rs +++ b/datafusion/core/src/datasource/mod.rs @@ -65,7 +65,7 @@ fn create_ordering( let mut sort_exprs = LexOrdering::default(); for sort in exprs { match &sort.expr { - Expr::Column(col) => match expressions::col(&col.name, schema) { + Expr::Column(col, _) => match expressions::col(&col.name, schema) { Ok(expr) => { sort_exprs.push(PhysicalSortExpr { expr, diff --git a/datafusion/core/src/datasource/physical_plan/parquet/row_filter.rs b/datafusion/core/src/datasource/physical_plan/parquet/row_filter.rs index a97e7c7d2552c..7424ba659c55e 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/row_filter.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/row_filter.rs @@ -432,7 +432,7 @@ pub fn can_expr_be_pushed_down_with_schemas( ) -> bool { let mut can_be_pushed = true; expr.apply(|expr| match expr { - datafusion_expr::Expr::Column(column) => { + datafusion_expr::Expr::Column(column, _) => { can_be_pushed &= !would_column_prevent_pushdown(column.name(), file_schema, table_schema); Ok(if can_be_pushed { @@ -699,7 +699,7 @@ mod test { .expect("expected error free record batch"); // Test all should fail - let expr = col("timestamp_col").lt(Expr::Literal( + let expr = col("timestamp_col").lt(Expr::literal( ScalarValue::TimestampNanosecond(Some(1), Some(Arc::from("UTC"))), )); let expr = logical2physical(&expr, &table_schema); @@ -723,7 +723,7 @@ mod test { assert!(matches!(filtered, Ok(a) if a == BooleanArray::from(vec![false; 8]))); // Test all should pass - let expr = col("timestamp_col").gt(Expr::Literal( + let expr = col("timestamp_col").gt(Expr::literal( ScalarValue::TimestampNanosecond(Some(0), Some(Arc::from("UTC"))), )); let expr = logical2physical(&expr, &table_schema); @@ -826,7 +826,7 @@ mod test { let expr = col("str_col") .is_not_null() - .or(col("int_col").gt(Expr::Literal(ScalarValue::UInt64(Some(5))))); + .or(col("int_col").gt(Expr::literal(ScalarValue::UInt64(Some(5))))); assert!(can_expr_be_pushed_down_with_schemas( &expr, diff --git a/datafusion/core/src/datasource/physical_plan/parquet/row_group_filter.rs b/datafusion/core/src/datasource/physical_plan/parquet/row_group_filter.rs index 7406676652f66..20db62e03b79a 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/row_group_filter.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/row_group_filter.rs @@ -1231,10 +1231,10 @@ mod tests { .run( lit("1").eq(lit("1")).and( col(r#""String""#) - .eq(Expr::Literal(ScalarValue::Utf8View(Some(String::from( + .eq(Expr::literal(ScalarValue::Utf8View(Some(String::from( "Hello_Not_Exists", ))))) - .or(col(r#""String""#).eq(Expr::Literal(ScalarValue::Utf8View( + .or(col(r#""String""#).eq(Expr::literal(ScalarValue::Utf8View( Some(String::from("Hello_Not_Exists2")), )))), ), @@ -1316,13 +1316,13 @@ mod tests { // generate pruning predicate `(String = "Hello") OR (String = "the quick") OR (String = "are you")` .run( col(r#""String""#) - .eq(Expr::Literal(ScalarValue::Utf8View(Some(String::from( + .eq(Expr::literal(ScalarValue::Utf8View(Some(String::from( "Hello", ))))) - .or(col(r#""String""#).eq(Expr::Literal(ScalarValue::Utf8View( + .or(col(r#""String""#).eq(Expr::literal(ScalarValue::Utf8View( Some(String::from("the quick")), )))) - .or(col(r#""String""#).eq(Expr::Literal(ScalarValue::Utf8View( + .or(col(r#""String""#).eq(Expr::literal(ScalarValue::Utf8View( Some(String::from("are you")), )))), ) diff --git a/datafusion/core/src/datasource/view.rs b/datafusion/core/src/datasource/view.rs index 1ffe54e4b06c1..9c2e6e6bb6f0b 100644 --- a/datafusion/core/src/datasource/view.rs +++ b/datafusion/core/src/datasource/view.rs @@ -139,7 +139,7 @@ impl TableProvider for ViewTable { let fields: Vec = projection .iter() .map(|i| { - Expr::Column(Column::from( + Expr::column(Column::from( self.logical_plan.schema().qualified_field(*i), )) }) diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index 5f01d41c31e73..18104e071373b 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -648,7 +648,7 @@ impl SessionContext { /// [`SQLOptions::verify_plan`]. pub async fn execute_logical_plan(&self, plan: LogicalPlan) -> Result { match plan { - LogicalPlan::Ddl(ddl) => { + LogicalPlan::Ddl(ddl, _) => { // Box::pin avoids allocating the stack space within this function's frame // for every one of these individual async functions, decreasing the risk of // stack overflows. @@ -681,18 +681,21 @@ impl SessionContext { DdlStatement::DropFunction(cmd) => { Box::pin(self.drop_function(cmd)).await } - ddl => Ok(DataFrame::new(self.state(), LogicalPlan::Ddl(ddl))), + ddl => Ok(DataFrame::new(self.state(), LogicalPlan::ddl(ddl))), } } // TODO what about the other statements (like TransactionStart and TransactionEnd) - LogicalPlan::Statement(Statement::SetVariable(stmt)) => { + LogicalPlan::Statement(Statement::SetVariable(stmt), _) => { self.set_variable(stmt).await } - LogicalPlan::Statement(Statement::Prepare(Prepare { - name, - input, - data_types, - })) => { + LogicalPlan::Statement( + Statement::Prepare(Prepare { + name, + input, + data_types, + }), + _, + ) => { // The number of parameters must match the specified data types length. if !data_types.is_empty() { let param_names = input.get_parameter_names()?; @@ -712,10 +715,10 @@ impl SessionContext { self.state.write().store_prepared(name, data_types, input)?; self.return_empty_dataframe() } - LogicalPlan::Statement(Statement::Execute(execute)) => { + LogicalPlan::Statement(Statement::Execute(execute), _) => { self.execute_prepared(execute) } - LogicalPlan::Statement(Statement::Deallocate(deallocate)) => { + LogicalPlan::Statement(Statement::Deallocate(deallocate), _) => { self.state .write() .remove_prepared(deallocate.name.as_str())?; @@ -1133,7 +1136,7 @@ impl SessionContext { let mut params: Vec = parameters .into_iter() .map(|e| match e { - Expr::Literal(scalar) => Ok(scalar), + Expr::Literal(scalar, _) => Ok(scalar), _ => not_impl_err!("Unsupported parameter type: {}", e), }) .collect::>()?; @@ -1769,16 +1772,16 @@ impl<'n, 'a> TreeNodeVisitor<'n> for BadPlanVisitor<'a> { fn f_down(&mut self, node: &'n Self::Node) -> Result { match node { - LogicalPlan::Ddl(ddl) if !self.options.allow_ddl => { + LogicalPlan::Ddl(ddl, _) if !self.options.allow_ddl => { plan_err!("DDL not supported: {}", ddl.name()) } - LogicalPlan::Dml(dml) if !self.options.allow_dml => { + LogicalPlan::Dml(dml, _) if !self.options.allow_dml => { plan_err!("DML not supported: {}", dml.op) } - LogicalPlan::Copy(_) if !self.options.allow_dml => { + LogicalPlan::Copy(_, _) if !self.options.allow_dml => { plan_err!("DML not supported: COPY") } - LogicalPlan::Statement(stmt) if !self.options.allow_statements => { + LogicalPlan::Statement(stmt, _) if !self.options.allow_statements => { plan_err!("Statement not supported: {}", stmt.name()) } _ => Ok(TreeNodeRecursion::Continue), diff --git a/datafusion/core/src/execution/session_state.rs b/datafusion/core/src/execution/session_state.rs index e99cf82223815..d6d126d06326d 100644 --- a/datafusion/core/src/execution/session_state.rs +++ b/datafusion/core/src/execution/session_state.rs @@ -632,7 +632,7 @@ impl SessionState { /// Optimizes the logical plan by applying optimizer rules. pub fn optimize(&self, plan: &LogicalPlan) -> datafusion_common::Result { - if let LogicalPlan::Explain(e) = plan { + if let LogicalPlan::Explain(e, _) = plan { let mut stringified_plans = e.stringified_plans.clone(); // analyze & capture output of each rule @@ -652,7 +652,7 @@ impl SessionState { stringified_plans .push(StringifiedPlan::new(plan_type, err.to_string())); - return Ok(LogicalPlan::Explain(Explain { + return Ok(LogicalPlan::explain(Explain { verbose: e.verbose, plan: Arc::clone(&e.plan), stringified_plans, @@ -688,7 +688,7 @@ impl SessionState { Err(e) => return Err(e), }; - Ok(LogicalPlan::Explain(Explain { + Ok(LogicalPlan::explain(Explain { verbose: e.verbose, plan, stringified_plans, diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 44537c951f945..659e6f0e2dec3 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -438,13 +438,16 @@ impl DefaultPhysicalPlanner { ) -> Result> { let exec_node: Arc = match node { // Leaves (no children) - LogicalPlan::TableScan(TableScan { - source, - projection, - filters, - fetch, - .. - }) => { + LogicalPlan::TableScan( + TableScan { + source, + projection, + filters, + fetch, + .. + }, + _, + ) => { let source = source_as_provider(source)?; // Remove all qualifiers from the scan as the provider // doesn't know (nor should care) how the relation was @@ -454,7 +457,7 @@ impl DefaultPhysicalPlanner { .scan(session_state, projection.as_ref(), &filters, *fetch) .await? } - LogicalPlan::Values(Values { values, schema }) => { + LogicalPlan::Values(Values { values, schema }, _) => { let exec_schema = schema.as_ref().to_owned().into(); let exprs = values .iter() @@ -469,34 +472,46 @@ impl DefaultPhysicalPlanner { let value_exec = ValuesExec::try_new(SchemaRef::new(exec_schema), exprs)?; Arc::new(value_exec) } - LogicalPlan::EmptyRelation(EmptyRelation { - produce_one_row: false, - schema, - }) => Arc::new(EmptyExec::new(SchemaRef::new( + LogicalPlan::EmptyRelation( + EmptyRelation { + produce_one_row: false, + schema, + }, + _, + ) => Arc::new(EmptyExec::new(SchemaRef::new( schema.as_ref().to_owned().into(), ))), - LogicalPlan::EmptyRelation(EmptyRelation { - produce_one_row: true, - schema, - }) => Arc::new(PlaceholderRowExec::new(SchemaRef::new( + LogicalPlan::EmptyRelation( + EmptyRelation { + produce_one_row: true, + schema, + }, + _, + ) => Arc::new(PlaceholderRowExec::new(SchemaRef::new( schema.as_ref().to_owned().into(), ))), - LogicalPlan::DescribeTable(DescribeTable { - schema, - output_schema, - }) => { + LogicalPlan::DescribeTable( + DescribeTable { + schema, + output_schema, + }, + _, + ) => { let output_schema: Schema = output_schema.as_ref().into(); self.plan_describe(Arc::clone(schema), Arc::new(output_schema))? } // 1 Child - LogicalPlan::Copy(CopyTo { - input, - output_url, - file_type, - partition_by, - options: source_option_tuples, - }) => { + LogicalPlan::Copy( + CopyTo { + input, + output_url, + file_type, + partition_by, + options: source_option_tuples, + }, + _, + ) => { let input_exec = children.one()?; let parsed_url = ListingTableUrl::parse(output_url)?; let object_store_url = parsed_url.object_store(); @@ -539,11 +554,14 @@ impl DefaultPhysicalPlanner { .create_writer_physical_plan(input_exec, session_state, config, None) .await? } - LogicalPlan::Dml(DmlStatement { - table_name, - op: WriteOp::Insert(insert_op), - .. - }) => { + LogicalPlan::Dml( + DmlStatement { + table_name, + op: WriteOp::Insert(insert_op), + .. + }, + _, + ) => { let name = table_name.table(); let schema = session_state.schema_for_ref(table_name.clone())?; if let Some(provider) = schema.table(name).await? { @@ -555,9 +573,12 @@ impl DefaultPhysicalPlanner { return exec_err!("Table '{table_name}' does not exist"); } } - LogicalPlan::Window(Window { - input, window_expr, .. - }) => { + LogicalPlan::Window( + Window { + input, window_expr, .. + }, + _, + ) => { if window_expr.is_empty() { return internal_err!("Impossibly got empty window expression"); } @@ -584,19 +605,25 @@ impl DefaultPhysicalPlanner { }; let get_sort_keys = |expr: &Expr| match expr { - Expr::WindowFunction(WindowFunction { - ref partition_by, - ref order_by, - .. - }) => generate_sort_key(partition_by, order_by), - Expr::Alias(Alias { expr, .. }) => { + Expr::WindowFunction( + WindowFunction { + ref partition_by, + ref order_by, + .. + }, + _, + ) => generate_sort_key(partition_by, order_by), + Expr::Alias(Alias { expr, .. }, _) => { // Convert &Box to &T match &**expr { - Expr::WindowFunction(WindowFunction { - ref partition_by, - ref order_by, - .. - }) => generate_sort_key(partition_by, order_by), + Expr::WindowFunction( + WindowFunction { + ref partition_by, + ref order_by, + .. + }, + _, + ) => generate_sort_key(partition_by, order_by), _ => unreachable!(), } } @@ -643,12 +670,15 @@ impl DefaultPhysicalPlanner { )?) } } - LogicalPlan::Aggregate(Aggregate { - input, - group_expr, - aggr_expr, - .. - }) => { + LogicalPlan::Aggregate( + Aggregate { + input, + group_expr, + aggr_expr, + .. + }, + _, + ) => { let options = session_state.config().options(); // Initially need to perform the aggregate and then merge the partitions let input_exec = children.one()?; @@ -758,16 +788,19 @@ impl DefaultPhysicalPlanner { Arc::clone(&physical_input_schema), )?) } - LogicalPlan::Projection(Projection { input, expr, .. }) => self + LogicalPlan::Projection(Projection { input, expr, .. }, _) => self .create_project_physical_exec( session_state, children.one()?, input, expr, )?, - LogicalPlan::Filter(Filter { - predicate, input, .. - }) => { + LogicalPlan::Filter( + Filter { + predicate, input, .. + }, + _, + ) => { let physical_input = children.one()?; let input_dfschema = input.schema(); @@ -781,10 +814,13 @@ impl DefaultPhysicalPlanner { let filter = FilterExec::try_new(runtime_expr, physical_input)?; Arc::new(filter.with_default_selectivity(selectivity)?) } - LogicalPlan::Repartition(Repartition { - input, - partitioning_scheme, - }) => { + LogicalPlan::Repartition( + Repartition { + input, + partitioning_scheme, + }, + _, + ) => { let physical_input = children.one()?; let input_dfschema = input.as_ref().schema(); let physical_partitioning = match partitioning_scheme { @@ -815,9 +851,12 @@ impl DefaultPhysicalPlanner { physical_partitioning, )?) } - LogicalPlan::Sort(Sort { - expr, input, fetch, .. - }) => { + LogicalPlan::Sort( + Sort { + expr, input, fetch, .. + }, + _, + ) => { let physical_input = children.one()?; let input_dfschema = input.as_ref().schema(); let sort_expr = create_physical_sort_exprs( @@ -829,9 +868,9 @@ impl DefaultPhysicalPlanner { SortExec::new(sort_expr, physical_input).with_fetch(*fetch); Arc::new(new_sort) } - LogicalPlan::Subquery(_) => todo!(), - LogicalPlan::SubqueryAlias(_) => children.one()?, - LogicalPlan::Limit(limit) => { + LogicalPlan::Subquery(_, _) => todo!(), + LogicalPlan::SubqueryAlias(_, _) => children.one()?, + LogicalPlan::Limit(limit, _) => { let input = children.one()?; let SkipType::Literal(skip) = limit.get_skip_type()? else { return not_impl_err!( @@ -861,13 +900,16 @@ impl DefaultPhysicalPlanner { Arc::new(GlobalLimitExec::new(input, skip, fetch)) } - LogicalPlan::Unnest(Unnest { - list_type_columns, - struct_type_columns, - schema, - options, - .. - }) => { + LogicalPlan::Unnest( + Unnest { + list_type_columns, + struct_type_columns, + schema, + options, + .. + }, + _, + ) => { let input = children.one()?; let schema = SchemaRef::new(schema.as_ref().to_owned().into()); let list_column_indices = list_type_columns @@ -887,23 +929,26 @@ impl DefaultPhysicalPlanner { } // 2 Children - LogicalPlan::Join(Join { - left, - right, - on: keys, - filter, - join_type, - null_equals_null, - schema: join_schema, - .. - }) => { + LogicalPlan::Join( + Join { + left, + right, + on: keys, + filter, + join_type, + null_equals_null, + schema: join_schema, + .. + }, + _, + ) => { let null_equals_null = *null_equals_null; let [physical_left, physical_right] = children.two()?; // If join has expression equijoin keys, add physical projection. let has_expr_join_key = keys.iter().any(|(l, r)| { - !(matches!(l, Expr::Column(_)) && matches!(r, Expr::Column(_))) + !(matches!(l, Expr::Column(_, _)) && matches!(r, Expr::Column(_, _))) }); let (new_logical, physical_left, physical_right) = if has_expr_join_key { // TODO: Can we extract this transformation to somewhere before physical plan @@ -925,7 +970,7 @@ impl DefaultPhysicalPlanner { let left = Arc::new(left); let right = Arc::new(right); - let new_join = LogicalPlan::Join(Join::try_new_with_project_input( + let new_join = LogicalPlan::join(Join::try_new_with_project_input( node, Arc::clone(&left), Arc::clone(&right), @@ -938,7 +983,7 @@ impl DefaultPhysicalPlanner { // If left_projected is true we are guaranteed that left is a Projection ( true, - LogicalPlan::Projection(Projection { input, expr, .. }), + LogicalPlan::Projection(Projection { input, expr, .. }, _), ) => self.create_project_physical_exec( session_state, physical_left, @@ -951,7 +996,7 @@ impl DefaultPhysicalPlanner { // If right_projected is true we are guaranteed that right is a Projection ( true, - LogicalPlan::Projection(Projection { input, expr, .. }), + LogicalPlan::Projection(Projection { input, expr, .. }, _), ) => self.create_project_physical_exec( session_state, physical_right, @@ -965,7 +1010,7 @@ impl DefaultPhysicalPlanner { if left_projected || right_projected { let final_join_result = join_schema.iter().map(Expr::from).collect::>(); - let projection = LogicalPlan::Projection(Projection::try_new( + let projection = LogicalPlan::projection(Projection::try_new( final_join_result, Arc::new(new_join), )?); @@ -982,19 +1027,25 @@ impl DefaultPhysicalPlanner { // Retrieving new left/right and join keys (in case plan was mutated above) let (left, right, keys, new_project) = match new_logical.as_ref() { - LogicalPlan::Projection(Projection { input, expr, .. }) => { - if let LogicalPlan::Join(Join { - left, right, on, .. - }) = input.as_ref() + LogicalPlan::Projection(Projection { input, expr, .. }, _) => { + if let LogicalPlan::Join( + Join { + left, right, on, .. + }, + _, + ) = input.as_ref() { (left, right, on, Some((input, expr))) } else { unreachable!() } } - LogicalPlan::Join(Join { - left, right, on, .. - }) => (left, right, on, None), + LogicalPlan::Join( + Join { + left, right, on, .. + }, + _, + ) => (left, right, on, None), // Should either be the original Join, or Join with a Projection on top _ => unreachable!(), }; @@ -1170,9 +1221,12 @@ impl DefaultPhysicalPlanner { join } } - LogicalPlan::RecursiveQuery(RecursiveQuery { - name, is_distinct, .. - }) => { + LogicalPlan::RecursiveQuery( + RecursiveQuery { + name, is_distinct, .. + }, + _, + ) => { let [static_term, recursive_term] = children.two()?; Arc::new(RecursiveQueryExec::try_new( name.clone(), @@ -1183,8 +1237,8 @@ impl DefaultPhysicalPlanner { } // N Children - LogicalPlan::Union(_) => Arc::new(UnionExec::new(children.vec())), - LogicalPlan::Extension(Extension { node }) => { + LogicalPlan::Union(_, _) => Arc::new(UnionExec::new(children.vec())), + LogicalPlan::Extension(Extension { node }, _) => { let mut maybe_plan = None; let children = children.vec(); for planner in &self.extension_planners { @@ -1224,16 +1278,16 @@ impl DefaultPhysicalPlanner { } // Other - LogicalPlan::Statement(statement) => { + LogicalPlan::Statement(statement, _) => { // DataFusion is a read-only query engine, but also a library, so consumers may implement this let name = statement.name(); return not_impl_err!("Unsupported logical plan: Statement({name})"); } - LogicalPlan::Dml(dml) => { + LogicalPlan::Dml(dml, _) => { // DataFusion is a read-only query engine, but also a library, so consumers may implement this return not_impl_err!("Unsupported logical plan: Dml({0})", dml.op); } - LogicalPlan::Ddl(ddl) => { + LogicalPlan::Ddl(ddl, _) => { // There is no default plan for DDl statements -- // it must be handled at a higher level (so that // the appropriate table can be registered with @@ -1241,17 +1295,17 @@ impl DefaultPhysicalPlanner { let name = ddl.name(); return not_impl_err!("Unsupported logical plan: {name}"); } - LogicalPlan::Explain(_) => { + LogicalPlan::Explain(_, _) => { return internal_err!( "Unsupported logical plan: Explain must be root of the plan" ) } - LogicalPlan::Distinct(_) => { + LogicalPlan::Distinct(_, _) => { return internal_err!( "Unsupported logical plan: Distinct should be replaced to Aggregate" ) } - LogicalPlan::Analyze(_) => { + LogicalPlan::Analyze(_, _) => { return internal_err!( "Unsupported logical plan: Analyze must be root of the plan" ) @@ -1269,7 +1323,7 @@ impl DefaultPhysicalPlanner { ) -> Result { if group_expr.len() == 1 { match &group_expr[0] { - Expr::GroupingSet(GroupingSet::GroupingSets(grouping_sets)) => { + Expr::GroupingSet(GroupingSet::GroupingSets(grouping_sets), _) => { merge_grouping_set_physical_expr( grouping_sets, input_dfschema, @@ -1277,13 +1331,15 @@ impl DefaultPhysicalPlanner { session_state, ) } - Expr::GroupingSet(GroupingSet::Cube(exprs)) => create_cube_physical_expr( - exprs, - input_dfschema, - input_schema, - session_state, - ), - Expr::GroupingSet(GroupingSet::Rollup(exprs)) => { + Expr::GroupingSet(GroupingSet::Cube(exprs), _) => { + create_cube_physical_expr( + exprs, + input_dfschema, + input_schema, + session_state, + ) + } + Expr::GroupingSet(GroupingSet::Rollup(exprs), _) => { create_rollup_physical_expr( exprs, input_dfschema, @@ -1516,14 +1572,17 @@ pub fn create_window_expr_with_name( let name = name.into(); let physical_schema: &Schema = &logical_schema.into(); match e { - Expr::WindowFunction(WindowFunction { - fun, - args, - partition_by, - order_by, - window_frame, - null_treatment, - }) => { + Expr::WindowFunction( + WindowFunction { + fun, + args, + partition_by, + order_by, + window_frame, + null_treatment, + }, + _, + ) => { let physical_args = create_physical_exprs(args, logical_schema, execution_props)?; let partition_by = @@ -1564,7 +1623,7 @@ pub fn create_window_expr( ) -> Result> { // unpack aliased logical expressions, e.g. "sum(col) over () as total" let (name, e) = match e { - Expr::Alias(Alias { expr, name, .. }) => (name.clone(), expr.as_ref()), + Expr::Alias(Alias { expr, name, .. }, _) => (name.clone(), expr.as_ref()), _ => (e.schema_name().to_string(), e), }; create_window_expr_with_name(e, name, logical_schema, execution_props) @@ -1587,14 +1646,17 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( execution_props: &ExecutionProps, ) -> Result { match e { - Expr::AggregateFunction(AggregateFunction { - func, - distinct, - args, - filter, - order_by, - null_treatment, - }) => { + Expr::AggregateFunction( + AggregateFunction { + func, + distinct, + args, + filter, + order_by, + null_treatment, + }, + _, + ) => { let name = if let Some(name) = name { name } else { @@ -1656,8 +1718,8 @@ pub fn create_aggregate_expr_and_maybe_filter( ) -> Result { // unpack (nested) aliased logical expressions, e.g. "sum(col) as total" let (name, e) = match e { - Expr::Alias(Alias { expr, name, .. }) => (Some(name.clone()), expr.as_ref()), - Expr::AggregateFunction(_) => (Some(e.schema_name().to_string()), e), + Expr::Alias(Alias { expr, name, .. }, _) => (Some(name.clone()), expr.as_ref()), + Expr::AggregateFunction(_, _) => (Some(e.schema_name().to_string()), e), _ => (None, e), }; @@ -1713,7 +1775,7 @@ impl DefaultPhysicalPlanner { logical_plan: &LogicalPlan, session_state: &SessionState, ) -> Result>> { - if let LogicalPlan::Explain(e) = logical_plan { + if let LogicalPlan::Explain(e, _) = logical_plan { use PlanType::*; let mut stringified_plans = vec![]; @@ -1836,7 +1898,7 @@ impl DefaultPhysicalPlanner { stringified_plans, e.verbose, )))) - } else if let LogicalPlan::Analyze(a) = logical_plan { + } else if let LogicalPlan::Analyze(a, _) = logical_plan { let input = self.create_physical_plan(&a.input, session_state).await?; let schema = SchemaRef::new((*a.schema).clone().into()); let show_statistics = session_state.config_options().explain.show_statistics; @@ -1969,7 +2031,7 @@ impl DefaultPhysicalPlanner { // // This depends on the invariant that logical schema field index MUST match // with physical schema field index. - let physical_name = if let Expr::Column(col) = e { + let physical_name = if let Expr::Column(col, _) = e { match input_schema.index_of_column(col) { Ok(idx) => { // index physical field using logical field index @@ -2175,7 +2237,7 @@ mod tests { ErrorExtensionPlanner {}, )]); - let logical_plan = LogicalPlan::Extension(Extension { + let logical_plan = LogicalPlan::extension(Extension { node: Arc::new(NoOpExtensionNode::default()), }); match planner @@ -2233,7 +2295,7 @@ mod tests { async fn default_extension_planner() { let session_state = make_session_state(); let planner = DefaultPhysicalPlanner::default(); - let logical_plan = LogicalPlan::Extension(Extension { + let logical_plan = LogicalPlan::extension(Extension { node: Arc::new(NoOpExtensionNode::default()), }); let plan = planner @@ -2260,7 +2322,7 @@ mod tests { BadExtensionPlanner {}, )]); - let logical_plan = LogicalPlan::Extension(Extension { + let logical_plan = LogicalPlan::extension(Extension { node: Arc::new(NoOpExtensionNode::default()), }); let plan = planner @@ -2371,7 +2433,7 @@ mod tests { #[tokio::test] async fn hash_agg_grouping_set_input_schema() -> Result<()> { - let grouping_set_expr = Expr::GroupingSet(GroupingSet::GroupingSets(vec![ + let grouping_set_expr = Expr::grouping_set(GroupingSet::GroupingSets(vec![ vec![col("c1")], vec![col("c2")], vec![col("c1"), col("c2")], @@ -2446,7 +2508,7 @@ mod tests { #[tokio::test] async fn hash_agg_grouping_set_by_partitioned() -> Result<()> { - let grouping_set_expr = Expr::GroupingSet(GroupingSet::GroupingSets(vec![ + let grouping_set_expr = Expr::grouping_set(GroupingSet::GroupingSets(vec![ vec![col("c1")], vec![col("c2")], vec![col("c1"), col("c2")], @@ -2702,7 +2764,7 @@ mod tests { let options = CsvReadOptions::new().schema_infer_max_records(100); let logical_plan = match ctx.read_csv(path, options).await?.into_optimized_plan()? { - LogicalPlan::TableScan(ref scan) => { + LogicalPlan::TableScan(ref scan, _) => { let mut scan = scan.clone(); let table_reference = TableReference::from(name); scan.table_name = table_reference; @@ -2712,7 +2774,7 @@ mod tests { .clone() .replace_qualifier(name.to_string()); scan.projected_schema = Arc::new(new_schema); - LogicalPlan::TableScan(scan) + LogicalPlan::table_scan(scan) } _ => unimplemented!(), }; diff --git a/datafusion/core/tests/custom_sources_cases/mod.rs b/datafusion/core/tests/custom_sources_cases/mod.rs index e1bd14105e23e..38773ef11f89f 100644 --- a/datafusion/core/tests/custom_sources_cases/mod.rs +++ b/datafusion/core/tests/custom_sources_cases/mod.rs @@ -236,11 +236,14 @@ async fn custom_source_dataframe() -> Result<()> { let optimized_plan = state.optimize(&logical_plan)?; match &optimized_plan { - LogicalPlan::TableScan(TableScan { - source, - projected_schema, - .. - }) => { + LogicalPlan::TableScan( + TableScan { + source, + projected_schema, + .. + }, + _, + ) => { assert_eq!(source.schema().fields().len(), 2); assert_eq!(projected_schema.fields().len(), 1); } diff --git a/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs b/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs index 09f7265d639a7..941dd7960bdc7 100644 --- a/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs +++ b/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs @@ -172,14 +172,14 @@ impl TableProvider for CustomProvider { let empty = Vec::new(); let projection = projection.unwrap_or(&empty); match &filters[0] { - Expr::BinaryExpr(BinaryExpr { right, .. }) => { + Expr::BinaryExpr(BinaryExpr { right, .. }, _) => { let int_value = match &**right { - Expr::Literal(ScalarValue::Int8(Some(i))) => *i as i64, - Expr::Literal(ScalarValue::Int16(Some(i))) => *i as i64, - Expr::Literal(ScalarValue::Int32(Some(i))) => *i as i64, - Expr::Literal(ScalarValue::Int64(Some(i))) => *i, - Expr::Cast(Cast { expr, data_type: _ }) => match expr.deref() { - Expr::Literal(lit_value) => match lit_value { + Expr::Literal(ScalarValue::Int8(Some(i)), _) => *i as i64, + Expr::Literal(ScalarValue::Int16(Some(i)), _) => *i as i64, + Expr::Literal(ScalarValue::Int32(Some(i)), _) => *i as i64, + Expr::Literal(ScalarValue::Int64(Some(i)), _) => *i, + Expr::Cast(Cast { expr, data_type: _ }, _) => match expr.deref() { + Expr::Literal(lit_value, _) => match lit_value { ScalarValue::Int8(Some(v)) => *v as i64, ScalarValue::Int16(Some(v)) => *v as i64, ScalarValue::Int32(Some(v)) => *v as i64, diff --git a/datafusion/core/tests/dataframe/dataframe_functions.rs b/datafusion/core/tests/dataframe/dataframe_functions.rs index 1bd90fce839d0..c025c6deb5cc5 100644 --- a/datafusion/core/tests/dataframe/dataframe_functions.rs +++ b/datafusion/core/tests/dataframe/dataframe_functions.rs @@ -31,7 +31,6 @@ use datafusion::prelude::*; use datafusion::assert_batches_eq; use datafusion_common::{DFSchema, ScalarValue}; -use datafusion_expr::expr::Alias; use datafusion_expr::ExprSchemable; use datafusion_functions_aggregate::expr_fn::{approx_median, approx_percentile_cont}; use datafusion_functions_nested::map::map; @@ -376,11 +375,7 @@ async fn test_fn_approx_percentile_cont() -> Result<()> { assert_batches_eq!(expected, &batches); // the arg2 parameter is a complex expr, but it can be evaluated to the literal value - let alias_expr = Expr::Alias(Alias::new( - cast(lit(0.5), DataType::Float32), - None::<&str>, - "arg_2".to_string(), - )); + let alias_expr = cast(lit(0.5), DataType::Float32).alias("arg_2".to_string()); let expr = approx_percentile_cont(col("b"), alias_expr, None); let df = create_test_table().await?; let expected = [ diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index 439aa6147e9b6..e7516432a285e 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -180,7 +180,7 @@ async fn test_count_wildcard_on_window() -> Result<()> { let df_results = ctx .table("t1") .await? - .select(vec![Expr::WindowFunction(expr::WindowFunction::new( + .select(vec![Expr::window_function(expr::WindowFunction::new( WindowFunctionDefinition::AggregateUDF(count_udaf()), vec![wildcard()], )) @@ -589,7 +589,7 @@ async fn select_with_alias_overwrite() -> Result<()> { #[tokio::test] async fn test_grouping_sets() -> Result<()> { - let grouping_set_expr = Expr::GroupingSet(GroupingSet::GroupingSets(vec![ + let grouping_set_expr = Expr::grouping_set(GroupingSet::GroupingSets(vec![ vec![col("a")], vec![col("b")], vec![col("a"), col("b")], @@ -631,7 +631,7 @@ async fn test_grouping_sets() -> Result<()> { async fn test_grouping_sets_count() -> Result<()> { let ctx = SessionContext::new(); - let grouping_set_expr = Expr::GroupingSet(GroupingSet::GroupingSets(vec![ + let grouping_set_expr = Expr::grouping_set(GroupingSet::GroupingSets(vec![ vec![col("c1")], vec![col("c2")], ])); @@ -671,7 +671,7 @@ async fn test_grouping_sets_count() -> Result<()> { async fn test_grouping_set_array_agg_with_overflow() -> Result<()> { let ctx = SessionContext::new(); - let grouping_set_expr = Expr::GroupingSet(GroupingSet::GroupingSets(vec![ + let grouping_set_expr = Expr::grouping_set(GroupingSet::GroupingSets(vec![ vec![col("c1")], vec![col("c2")], vec![col("c1"), col("c2")], @@ -1629,7 +1629,7 @@ async fn consecutive_projection_same_schema() -> Result<()> { // Add `t` column full of nulls let df = df - .with_column("t", cast(Expr::Literal(ScalarValue::Null), DataType::Int32)) + .with_column("t", cast(Expr::literal(ScalarValue::Null), DataType::Int32)) .unwrap(); df.clone().show().await.unwrap(); diff --git a/datafusion/core/tests/execution/logical_plan.rs b/datafusion/core/tests/execution/logical_plan.rs index 168bf484e5411..e118ae7226a89 100644 --- a/datafusion/core/tests/execution/logical_plan.rs +++ b/datafusion/core/tests/execution/logical_plan.rs @@ -40,24 +40,24 @@ async fn count_only_nulls() -> Result<()> { vec![Field::new("col", DataType::Null, true)].into(), HashMap::new(), )?); - let input = Arc::new(LogicalPlan::Values(Values { + let input = Arc::new(LogicalPlan::values(Values { schema: input_schema, values: vec![ - vec![Expr::Literal(ScalarValue::Null)], - vec![Expr::Literal(ScalarValue::Null)], - vec![Expr::Literal(ScalarValue::Null)], + vec![Expr::literal(ScalarValue::Null)], + vec![Expr::literal(ScalarValue::Null)], + vec![Expr::literal(ScalarValue::Null)], ], })); - let input_col_ref = Expr::Column(Column { + let input_col_ref = Expr::column(Column { relation: None, name: "col".to_string(), }); // Aggregation: count(col) AS count - let aggregate = LogicalPlan::Aggregate(Aggregate::try_new( + let aggregate = LogicalPlan::aggregate(Aggregate::try_new( input, vec![], - vec![Expr::AggregateFunction(AggregateFunction { + vec![Expr::aggregate_function(AggregateFunction { func: Arc::new(AggregateUDF::new_from_impl(Count::new())), args: vec![input_col_ref], distinct: false, diff --git a/datafusion/core/tests/expr_api/simplification.rs b/datafusion/core/tests/expr_api/simplification.rs index 1e6ff8088d0af..49ac161895347 100644 --- a/datafusion/core/tests/expr_api/simplification.rs +++ b/datafusion/core/tests/expr_api/simplification.rs @@ -190,7 +190,7 @@ fn make_udf_add(volatility: Volatility) -> Arc { } fn cast_to_int64_expr(expr: Expr) -> Expr { - Expr::Cast(Cast::new(expr.into(), DataType::Int64)) + Expr::cast(Cast::new(expr.into(), DataType::Int64)) } fn to_timestamp_expr(arg: impl Into) -> Expr { @@ -282,7 +282,7 @@ fn select_date_plus_interval() -> Result<()> { let date_plus_interval_expr = to_timestamp_expr(ts_string) .cast_to(&DataType::Date32, schema)? - + Expr::Literal(ScalarValue::IntervalDayTime(Some(IntervalDayTime { + + Expr::literal(ScalarValue::IntervalDayTime(Some(IntervalDayTime { days: 123, milliseconds: 0, }))); @@ -391,7 +391,7 @@ fn test_const_evaluator_scalar_functions() { // rand() + (1 + 2) --> rand() + 3 let fun = math::random(); assert_eq!(fun.signature().volatility, Volatility::Volatile); - let rand = Expr::ScalarFunction(ScalarFunction::new_udf(fun, vec![])); + let rand = Expr::scalar_function(ScalarFunction::new_udf(fun, vec![])); let expr = rand.clone() + (lit(1) + lit(2)); let expected = rand + lit(3); test_evaluate(expr, expected); @@ -399,7 +399,7 @@ fn test_const_evaluator_scalar_functions() { // parenthesization matters: can't rewrite // (rand() + 1) + 2 --> (rand() + 1) + 2) let fun = math::random(); - let rand = Expr::ScalarFunction(ScalarFunction::new_udf(fun, vec![])); + let rand = Expr::scalar_function(ScalarFunction::new_udf(fun, vec![])); let expr = (rand + lit(1)) + lit(2); test_evaluate(expr.clone(), expr); } @@ -429,7 +429,7 @@ fn test_evaluator_udfs() { // immutable UDF should get folded // udf_add(1+2, 30+40) --> 73 - let expr = Expr::ScalarFunction(ScalarFunction::new_udf( + let expr = Expr::scalar_function(ScalarFunction::new_udf( make_udf_add(Volatility::Immutable), args.clone(), )); @@ -439,15 +439,15 @@ fn test_evaluator_udfs() { // udf_add(1+2, 30+40) --> 73 let fun = make_udf_add(Volatility::Stable); let expr = - Expr::ScalarFunction(ScalarFunction::new_udf(Arc::clone(&fun), args.clone())); + Expr::scalar_function(ScalarFunction::new_udf(Arc::clone(&fun), args.clone())); test_evaluate(expr, lit(73)); // volatile UDF should have args folded // udf_add(1+2, 30+40) --> udf_add(3, 70) let fun = make_udf_add(Volatility::Volatile); - let expr = Expr::ScalarFunction(ScalarFunction::new_udf(Arc::clone(&fun), args)); + let expr = Expr::scalar_function(ScalarFunction::new_udf(Arc::clone(&fun), args)); let expected_expr = - Expr::ScalarFunction(ScalarFunction::new_udf(Arc::clone(&fun), folded_args)); + Expr::scalar_function(ScalarFunction::new_udf(Arc::clone(&fun), folded_args)); test_evaluate(expr, expected_expr); } diff --git a/datafusion/core/tests/optimizer/mod.rs b/datafusion/core/tests/optimizer/mod.rs index f17d13a420607..8fd5b74876dd9 100644 --- a/datafusion/core/tests/optimizer/mod.rs +++ b/datafusion/core/tests/optimizer/mod.rs @@ -301,7 +301,7 @@ fn test_inequalities_non_null_bounded() { (col("x").not_between(lit(0), lit(5)), false), (col("x").not_between(lit(5), lit(10)), true), ( - Expr::BinaryExpr(BinaryExpr { + Expr::binary_expr(BinaryExpr { left: Box::new(col("x")), op: Operator::IsDistinctFrom, right: Box::new(lit(ScalarValue::Null)), @@ -309,7 +309,7 @@ fn test_inequalities_non_null_bounded() { true, ), ( - Expr::BinaryExpr(BinaryExpr { + Expr::binary_expr(BinaryExpr { left: Box::new(col("x")), op: Operator::IsDistinctFrom, right: Box::new(lit(5)), diff --git a/datafusion/core/tests/sql/mod.rs b/datafusion/core/tests/sql/mod.rs index 177427b47d218..32705163bcb2d 100644 --- a/datafusion/core/tests/sql/mod.rs +++ b/datafusion/core/tests/sql/mod.rs @@ -326,11 +326,14 @@ async fn nyc() -> Result<()> { let optimized_plan = dataframe.into_optimized_plan().unwrap(); match &optimized_plan { - LogicalPlan::Aggregate(Aggregate { input, .. }) => match input.as_ref() { - LogicalPlan::TableScan(TableScan { - ref projected_schema, - .. - }) => { + LogicalPlan::Aggregate(Aggregate { input, .. }, _) => match input.as_ref() { + LogicalPlan::TableScan( + TableScan { + ref projected_schema, + .. + }, + _, + ) => { assert_eq!(2, projected_schema.fields().len()); assert_eq!(projected_schema.field(0).name(), "passenger_count"); assert_eq!(projected_schema.field(1).name(), "fare_amount"); diff --git a/datafusion/core/tests/user_defined/expr_planner.rs b/datafusion/core/tests/user_defined/expr_planner.rs index ad9c1280d6b11..2890caa7caa28 100644 --- a/datafusion/core/tests/user_defined/expr_planner.rs +++ b/datafusion/core/tests/user_defined/expr_planner.rs @@ -25,7 +25,6 @@ use datafusion::logical_expr::Operator; use datafusion::prelude::*; use datafusion::sql::sqlparser::ast::BinaryOperator; use datafusion_common::ScalarValue; -use datafusion_expr::expr::Alias; use datafusion_expr::planner::{ExprPlanner, PlannerResult, RawBinaryExpr}; use datafusion_expr::BinaryExpr; @@ -40,26 +39,23 @@ impl ExprPlanner for MyCustomPlanner { ) -> Result> { match &expr.op { BinaryOperator::Arrow => { - Ok(PlannerResult::Planned(Expr::BinaryExpr(BinaryExpr { + Ok(PlannerResult::Planned(Expr::binary_expr(BinaryExpr { left: Box::new(expr.left.clone()), right: Box::new(expr.right.clone()), op: Operator::StringConcat, }))) } BinaryOperator::LongArrow => { - Ok(PlannerResult::Planned(Expr::BinaryExpr(BinaryExpr { + Ok(PlannerResult::Planned(Expr::binary_expr(BinaryExpr { left: Box::new(expr.left.clone()), right: Box::new(expr.right.clone()), op: Operator::Plus, }))) } - BinaryOperator::Question => { - Ok(PlannerResult::Planned(Expr::Alias(Alias::new( - Expr::Literal(ScalarValue::Boolean(Some(true))), - None::<&str>, - format!("{} ? {}", expr.left, expr.right), - )))) - } + BinaryOperator::Question => Ok(PlannerResult::Planned( + Expr::literal(ScalarValue::Boolean(Some(true))) + .alias(format!("{} ? {}", expr.left, expr.right)), + )), _ => Ok(PlannerResult::Original(expr)), } } diff --git a/datafusion/core/tests/user_defined/user_defined_plan.rs b/datafusion/core/tests/user_defined/user_defined_plan.rs index 520a91aeb4d6f..8b313d13e83e1 100644 --- a/datafusion/core/tests/user_defined/user_defined_plan.rs +++ b/datafusion/core/tests/user_defined/user_defined_plan.rs @@ -360,22 +360,25 @@ impl OptimizerRule for TopKOptimizerRule { // Note: this code simply looks for the pattern of a Limit followed by a // Sort and replaces it by a TopK node. It does not handle many // edge cases (e.g multiple sort columns, sort ASC / DESC), etc. - let LogicalPlan::Limit(ref limit) = plan else { + let LogicalPlan::Limit(ref limit, _) = plan else { return Ok(Transformed::no(plan)); }; let FetchType::Literal(Some(fetch)) = limit.get_fetch_type()? else { return Ok(Transformed::no(plan)); }; - if let LogicalPlan::Sort(Sort { - ref expr, - ref input, - .. - }) = limit.input.as_ref() + if let LogicalPlan::Sort( + Sort { + ref expr, + ref input, + .. + }, + _, + ) = limit.input.as_ref() { if expr.len() == 1 { // we found a sort with a single sort expr, replace with a a TopK - return Ok(Transformed::yes(LogicalPlan::Extension(Extension { + return Ok(Transformed::yes(LogicalPlan::extension(Extension { node: Arc::new(TopKPlanNode { k: fetch, input: input.as_ref().clone(), @@ -705,9 +708,9 @@ impl MyAnalyzerRule { fn analyze_plan(plan: LogicalPlan) -> Result { plan.transform(|plan| { Ok(match plan { - LogicalPlan::Projection(projection) => { + LogicalPlan::Projection(projection, _) => { let expr = Self::analyze_expr(projection.expr.clone())?; - Transformed::yes(LogicalPlan::Projection(Projection::try_new( + Transformed::yes(LogicalPlan::projection(Projection::try_new( expr, projection.input, )?)) @@ -723,9 +726,9 @@ impl MyAnalyzerRule { .map(|e| { e.transform(|e| { Ok(match e { - Expr::Literal(ScalarValue::Int64(i)) => { + Expr::Literal(ScalarValue::Int64(i), _) => { // transform to UInt64 - Transformed::yes(Expr::Literal(ScalarValue::UInt64( + Transformed::yes(Expr::literal(ScalarValue::UInt64( i.map(|i| i as u64), ))) } diff --git a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs index cf403e5d640f1..9d4d5321638a0 100644 --- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs @@ -711,7 +711,7 @@ impl ScalarUDFImpl for CastToI64UDF { arg } else { // need to use an actual cast to get the correct type - Expr::Cast(datafusion_expr::Cast { + Expr::cast(datafusion_expr::Cast { expr: Box::new(arg), data_type: DataType::Int64, }) @@ -829,7 +829,7 @@ impl ScalarUDFImpl for TakeUDF { return plan_err!("Expected 3 arguments, got {}.", arg_exprs.len()); } - let take_idx = if let Some(Expr::Literal(ScalarValue::Int64(Some(idx)))) = + let take_idx = if let Some(Expr::Literal(ScalarValue::Int64(Some(idx)), _)) = arg_exprs.get(2) { if *idx == 0 || *idx == 1 { @@ -980,7 +980,7 @@ impl ScalarFunctionWrapper { fn replacement(expr: &Expr, args: &[Expr]) -> Result { let result = expr.clone().transform(|e| { let r = match e { - Expr::Placeholder(placeholder) => { + Expr::Placeholder(placeholder, _) => { let placeholder_position = Self::parse_placeholder_identifier(&placeholder.id)?; if placeholder_position < args.len() { diff --git a/datafusion/core/tests/user_defined/user_defined_table_functions.rs b/datafusion/core/tests/user_defined/user_defined_table_functions.rs index 0cc156866d4d1..b41eafee89267 100644 --- a/datafusion/core/tests/user_defined/user_defined_table_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_table_functions.rs @@ -165,7 +165,7 @@ impl SimpleCsvTable { async fn interpreter_expr(&self, state: &dyn Session) -> Result { use datafusion::logical_expr::expr_rewriter::normalize_col; use datafusion::logical_expr::utils::columnize_expr; - let plan = LogicalPlan::EmptyRelation(EmptyRelation { + let plan = LogicalPlan::empty_relation(EmptyRelation { produce_one_row: true, schema: Arc::new(DFSchema::empty()), }); @@ -176,7 +176,7 @@ impl SimpleCsvTable { )?], Arc::new(plan), ) - .map(LogicalPlan::Projection)?; + .map(LogicalPlan::projection)?; let rbs = collect( state.create_physical_plan(&logical_plan).await?, Arc::new(TaskContext::from(state)), @@ -201,7 +201,7 @@ impl TableFunctionImpl for SimpleCsvTableFunc { let mut filepath = String::new(); for expr in exprs { match expr { - Expr::Literal(ScalarValue::Utf8(Some(ref path))) => { + Expr::Literal(ScalarValue::Utf8(Some(ref path)), _) => { filepath.clone_from(path); } expr => new_exprs.push(expr.clone()), diff --git a/datafusion/expr/Cargo.toml b/datafusion/expr/Cargo.toml index 438662e0642b0..405710f2c0190 100644 --- a/datafusion/expr/Cargo.toml +++ b/datafusion/expr/Cargo.toml @@ -49,6 +49,7 @@ datafusion-expr-common = { workspace = true } datafusion-functions-aggregate-common = { workspace = true } datafusion-functions-window-common = { workspace = true } datafusion-physical-expr-common = { workspace = true } +enumset = { workspace = true } indexmap = { workspace = true } paste = "^1.0" recursive = { workspace = true } diff --git a/datafusion/expr/src/conditional_expressions.rs b/datafusion/expr/src/conditional_expressions.rs index 9cb51612d0cab..e9daee918b7c0 100644 --- a/datafusion/expr/src/conditional_expressions.rs +++ b/datafusion/expr/src/conditional_expressions.rs @@ -72,7 +72,7 @@ impl CaseBuilder { let then_types: Vec = then_expr .iter() .map(|e| match e { - Expr::Literal(_) => e.get_type(&DFSchema::empty()), + Expr::Literal(_, _) => e.get_type(&DFSchema::empty()), _ => Ok(DataType::Null), }) .collect::>>()?; @@ -88,7 +88,7 @@ impl CaseBuilder { } } - Ok(Expr::Case(Case::new( + Ok(Expr::case(Case::new( self.expr.clone(), self.when_expr .iter() diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 8cfee18fe4e21..53d64fe31eeca 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -29,6 +29,7 @@ use crate::utils::expr_to_columns; use crate::Volatility; use crate::{udaf, ExprSchemable, Operator, Signature, WindowFrame, WindowUDF}; +use crate::logical_plan::tree_node::{LogicalPlanPattern, LogicalPlanStats}; use arrow::datatypes::{DataType, FieldRef}; use datafusion_common::cse::HashNode; use datafusion_common::tree_node::{ @@ -38,6 +39,7 @@ use datafusion_common::{ plan_err, Column, DFSchema, HashMap, Result, ScalarValue, TableReference, }; use datafusion_functions_window_common::field::WindowUDFFieldArgs; +use enumset::enum_set; use sqlparser::ast::{ display_comma_separated, ExceptSelectItem, ExcludeSelectItem, IlikeSelectItem, NullTreatment, RenameSelectItem, ReplaceSelectElement, @@ -91,7 +93,7 @@ use sqlparser::ast::{ /// # use datafusion_common::Column; /// # use datafusion_expr::{lit, col, Expr}; /// let expr = col("c1"); -/// assert_eq!(expr, Expr::Column(Column::from_name("c1"))); +/// assert_eq!(expr, Expr::column(Column::from_name("c1"))); /// ``` /// /// [`Expr::Literal`] refer to literal, or constant, values. These are created @@ -104,9 +106,9 @@ use sqlparser::ast::{ /// # use datafusion_expr::{lit, col, Expr}; /// // All literals are strongly typed in DataFusion. To make an `i64` 42: /// let expr = lit(42i64); -/// assert_eq!(expr, Expr::Literal(ScalarValue::Int64(Some(42)))); +/// assert_eq!(expr, Expr::literal(ScalarValue::Int64(Some(42)))); /// // To make a (typed) NULL: -/// let expr = Expr::Literal(ScalarValue::Int64(None)); +/// let expr = Expr::literal(ScalarValue::Int64(None)); /// // to make an (untyped) NULL (the optimizer will coerce this to the correct type): /// let expr = lit(ScalarValue::Null); /// ``` @@ -122,7 +124,7 @@ use sqlparser::ast::{ /// // Use the `+` operator to add two columns together /// let expr = col("c1") + col("c2"); /// assert!(matches!(expr, Expr::BinaryExpr { ..} )); -/// if let Expr::BinaryExpr(binary_expr) = expr { +/// if let Expr::BinaryExpr(binary_expr, _) = expr { /// assert_eq!(*binary_expr.left, col("c1")); /// assert_eq!(*binary_expr.right, col("c2")); /// assert_eq!(binary_expr.op, Operator::Plus); @@ -137,10 +139,10 @@ use sqlparser::ast::{ /// # use datafusion_expr::{lit, col, Operator, Expr}; /// let expr = col("c1").eq(lit(42_i32)); /// assert!(matches!(expr, Expr::BinaryExpr { .. } )); -/// if let Expr::BinaryExpr(binary_expr) = expr { +/// if let Expr::BinaryExpr(binary_expr, _) = expr { /// assert_eq!(*binary_expr.left, col("c1")); /// let scalar = ScalarValue::Int32(Some(42)); -/// assert_eq!(*binary_expr.right, Expr::Literal(scalar)); +/// assert_eq!(*binary_expr.right, Expr::literal(scalar)); /// assert_eq!(binary_expr.op, Operator::Eq); /// } /// ``` @@ -186,7 +188,7 @@ use sqlparser::ast::{ /// let mut scalars = HashSet::new(); /// // apply recursively visits all nodes in the expression tree /// expr.apply(|e| { -/// if let Expr::Literal(scalar) = e { +/// if let Expr::Literal(scalar, _) = e { /// scalars.insert(scalar); /// } /// // The return value controls whether to continue visiting the tree @@ -208,7 +210,7 @@ use sqlparser::ast::{ /// let expr = col("a").eq(lit(5)).and(col("b").eq(lit(6))); /// // rewrite all references to column "a" to the literal 42 /// let rewritten = expr.transform(|e| { -/// if let Expr::Column(c) = &e { +/// if let Expr::Column(c, _) = &e { /// if &c.name == "a" { /// // return Transformed::yes to indicate the node was changed /// return Ok(Transformed::yes(lit(42))) @@ -221,44 +223,45 @@ use sqlparser::ast::{ /// assert!(rewritten.transformed); /// // to 42 = 5 AND b = 6 /// assert_eq!(rewritten.data, lit(42).eq(lit(5)).and(col("b").eq(lit(6)))); +#[allow(clippy::large_enum_variant)] #[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)] pub enum Expr { /// An expression with a specific name. - Alias(Alias), + Alias(Alias, LogicalPlanStats), /// A named reference to a qualified field in a schema. - Column(Column), + Column(Column, LogicalPlanStats), /// A named reference to a variable in a registry. - ScalarVariable(DataType, Vec), + ScalarVariable(DataType, Vec, LogicalPlanStats), /// A constant value. - Literal(ScalarValue), + Literal(ScalarValue, LogicalPlanStats), /// A binary expression such as "age > 21" - BinaryExpr(BinaryExpr), + BinaryExpr(BinaryExpr, LogicalPlanStats), /// LIKE expression - Like(Like), + Like(Like, LogicalPlanStats), /// LIKE expression that uses regular expressions - SimilarTo(Like), + SimilarTo(Like, LogicalPlanStats), /// Negation of an expression. The expression's type must be a boolean to make sense. - Not(Box), + Not(Box, LogicalPlanStats), /// True if argument is not NULL, false otherwise. This expression itself is never NULL. - IsNotNull(Box), + IsNotNull(Box, LogicalPlanStats), /// True if argument is NULL, false otherwise. This expression itself is never NULL. - IsNull(Box), + IsNull(Box, LogicalPlanStats), /// True if argument is true, false otherwise. This expression itself is never NULL. - IsTrue(Box), + IsTrue(Box, LogicalPlanStats), /// True if argument is false, false otherwise. This expression itself is never NULL. - IsFalse(Box), + IsFalse(Box, LogicalPlanStats), /// True if argument is NULL, false otherwise. This expression itself is never NULL. - IsUnknown(Box), + IsUnknown(Box, LogicalPlanStats), /// True if argument is FALSE or NULL, false otherwise. This expression itself is never NULL. - IsNotTrue(Box), + IsNotTrue(Box, LogicalPlanStats), /// True if argument is TRUE OR NULL, false otherwise. This expression itself is never NULL. - IsNotFalse(Box), + IsNotFalse(Box, LogicalPlanStats), /// True if argument is TRUE or FALSE, false otherwise. This expression itself is never NULL. - IsNotUnknown(Box), + IsNotUnknown(Box, LogicalPlanStats), /// arithmetic negation of an expression, the operand must be of a signed numeric data type - Negative(Box), + Negative(Box, LogicalPlanStats), /// Whether an expression is between a given range. - Between(Between), + Between(Between, LogicalPlanStats), /// The CASE expression is similar to a series of nested if/else and there are two forms that /// can be used. The first form consists of a series of boolean "when" expressions with /// corresponding "then" expressions, and an optional "else" expression. @@ -280,61 +283,61 @@ pub enum Expr { /// [ELSE result] /// END /// ``` - Case(Case), + Case(Case, LogicalPlanStats), /// Casts the expression to a given type and will return a runtime error if the expression cannot be cast. /// This expression is guaranteed to have a fixed type. - Cast(Cast), + Cast(Cast, LogicalPlanStats), /// Casts the expression to a given type and will return a null value if the expression cannot be cast. /// This expression is guaranteed to have a fixed type. - TryCast(TryCast), + TryCast(TryCast, LogicalPlanStats), /// Represents the call of a scalar function with a set of arguments. - ScalarFunction(ScalarFunction), + ScalarFunction(ScalarFunction, LogicalPlanStats), /// Calls an aggregate function with arguments, and optional /// `ORDER BY`, `FILTER`, `DISTINCT` and `NULL TREATMENT`. /// /// See also [`ExprFunctionExt`] to set these fields. /// /// [`ExprFunctionExt`]: crate::expr_fn::ExprFunctionExt - AggregateFunction(AggregateFunction), + AggregateFunction(AggregateFunction, LogicalPlanStats), /// Represents the call of a window function with arguments. - WindowFunction(WindowFunction), + WindowFunction(WindowFunction, LogicalPlanStats), /// Returns whether the list contains the expr value. - InList(InList), + InList(InList, LogicalPlanStats), /// EXISTS subquery - Exists(Exists), + Exists(Exists, LogicalPlanStats), /// IN subquery - InSubquery(InSubquery), + InSubquery(InSubquery, LogicalPlanStats), /// Scalar subquery - ScalarSubquery(Subquery), + ScalarSubquery(Subquery, LogicalPlanStats), /// Represents a reference to all available fields in a specific schema, /// with an optional (schema) qualifier. /// /// This expr has to be resolved to a list of columns before translating logical /// plan into physical plan. - Wildcard(Wildcard), + Wildcard(Wildcard, LogicalPlanStats), /// List of grouping set expressions. Only valid in the context of an aggregate /// GROUP BY expression list - GroupingSet(GroupingSet), + GroupingSet(GroupingSet, LogicalPlanStats), /// A place holder for parameters in a prepared statement /// (e.g. `$foo` or `$1`) - Placeholder(Placeholder), + Placeholder(Placeholder, LogicalPlanStats), /// A place holder which hold a reference to a qualified field /// in the outer query, used for correlated sub queries. - OuterReferenceColumn(DataType, Column), + OuterReferenceColumn(DataType, Column, LogicalPlanStats), /// Unnest expression - Unnest(Unnest), + Unnest(Unnest, LogicalPlanStats), } impl Default for Expr { fn default() -> Self { - Expr::Literal(ScalarValue::Null) + Expr::literal(ScalarValue::Null) } } /// Create an [`Expr`] from a [`Column`] impl From for Expr { fn from(value: Column) -> Self { - Expr::Column(value) + Expr::column(value) } } @@ -371,6 +374,12 @@ pub struct Wildcard { pub options: WildcardOptions, } +impl Wildcard { + pub(crate) fn stats(&self) -> LogicalPlanStats { + self.options.stats() + } +} + /// UNNEST expression. #[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)] pub struct Unnest { @@ -389,6 +398,10 @@ impl Unnest { pub fn new_boxed(boxed: Box) -> Self { Self { expr: boxed } } + + pub(crate) fn stats(&self) -> LogicalPlanStats { + self.expr.stats() + } } /// Alias expression @@ -412,6 +425,10 @@ impl Alias { name: name.into(), } } + + pub(crate) fn stats(&self) -> LogicalPlanStats { + self.expr.stats() + } } /// Binary expression @@ -430,6 +447,10 @@ impl BinaryExpr { pub fn new(left: Box, op: Operator, right: Box) -> Self { Self { left, op, right } } + + pub(crate) fn stats(&self) -> LogicalPlanStats { + self.left.stats().merge(self.right.stats()) + } } impl Display for BinaryExpr { @@ -445,7 +466,7 @@ impl Display for BinaryExpr { precedence: u8, ) -> fmt::Result { match expr { - Expr::BinaryExpr(child) => { + Expr::BinaryExpr(child, _) => { let p = child.op.precedence(); if p == 0 || p < precedence { write!(f, "({child})")?; @@ -489,6 +510,18 @@ impl Case { else_expr, } } + + pub(crate) fn stats(&self) -> LogicalPlanStats { + self.expr + .iter() + .chain( + self.when_then_expr + .iter() + .flat_map(|(w, t)| vec![w, t]) + .chain(self.else_expr.iter()), + ) + .fold(LogicalPlanStats::empty(), |s, e| s.merge(e.stats())) + } } /// LIKE expression @@ -519,6 +552,10 @@ impl Like { case_insensitive, } } + + pub(crate) fn stats(&self) -> LogicalPlanStats { + self.expr.stats().merge(self.pattern.stats()) + } } /// BETWEEN expression @@ -544,6 +581,13 @@ impl Between { high, } } + + pub(crate) fn stats(&self) -> LogicalPlanStats { + self.expr + .stats() + .merge(self.low.stats()) + .merge(self.high.stats()) + } } /// ScalarFunction expression invokes a built-in scalar function @@ -560,13 +604,17 @@ impl ScalarFunction { pub fn name(&self) -> &str { self.func.name() } -} -impl ScalarFunction { /// Create a new ScalarFunction expression with a user-defined function (UDF) pub fn new_udf(udf: Arc, args: Vec) -> Self { Self { func: udf, args } } + + pub(crate) fn stats(&self) -> LogicalPlanStats { + self.args + .iter() + .fold(LogicalPlanStats::empty(), |s, e| s.merge(e.stats())) + } } /// Access a sub field of a nested type, such as `Field` or `List` @@ -598,6 +646,10 @@ impl Cast { pub fn new(expr: Box, data_type: DataType) -> Self { Self { expr, data_type } } + + pub(crate) fn stats(&self) -> LogicalPlanStats { + self.expr.stats() + } } /// TryCast Expression @@ -614,6 +666,10 @@ impl TryCast { pub fn new(expr: Box, data_type: DataType) -> Self { Self { expr, data_type } } + + pub(crate) fn stats(&self) -> LogicalPlanStats { + self.expr.stats() + } } /// SORT expression @@ -730,6 +786,14 @@ impl AggregateFunction { null_treatment, } } + + pub(crate) fn stats(&self) -> LogicalPlanStats { + self.args + .iter() + .chain(self.filter.iter().map(|e| e.as_ref())) + .chain(self.order_by.iter().flatten().map(|s| &s.expr)) + .fold(LogicalPlanStats::empty(), |s, e| s.merge(e.stats())) + } } /// A function used as a SQL window function @@ -842,6 +906,14 @@ impl WindowFunction { null_treatment: None, } } + + pub(crate) fn stats(&self) -> LogicalPlanStats { + self.args + .iter() + .chain(self.partition_by.iter()) + .chain(self.order_by.iter().map(|s| &s.expr)) + .fold(LogicalPlanStats::empty(), |s, e| s.merge(e.stats())) + } } /// EXISTS expression @@ -858,6 +930,10 @@ impl Exists { pub fn new(subquery: Subquery, negated: bool) -> Self { Self { subquery, negated } } + + pub(crate) fn stats(&self) -> LogicalPlanStats { + self.subquery.stats() + } } /// User Defined Aggregate Function @@ -912,6 +988,12 @@ impl InList { negated, } } + + pub(crate) fn stats(&self) -> LogicalPlanStats { + self.list + .iter() + .fold(self.expr.stats(), |s, e| s.merge(e.stats())) + } } /// IN subquery @@ -934,6 +1016,10 @@ impl InSubquery { negated, } } + + pub(crate) fn stats(&self) -> LogicalPlanStats { + self.expr.stats().merge(self.subquery.stats()) + } } /// Placeholder, representing bind parameter values such as `$1` or `$name`. @@ -991,6 +1077,18 @@ impl GroupingSet { } } } + + pub(crate) fn stats(&self) -> LogicalPlanStats { + match self { + GroupingSet::Rollup(exprs) | GroupingSet::Cube(exprs) => exprs + .iter() + .fold(LogicalPlanStats::empty(), |s, e| s.merge(e.stats())), + GroupingSet::GroupingSets(groups) => groups + .iter() + .flatten() + .fold(LogicalPlanStats::empty(), |s, e| s.merge(e.stats())), + } + } } /// Additional options for wildcards, e.g. Snowflake `EXCLUDE`/`RENAME` and Bigquery `EXCEPT`. @@ -1026,6 +1124,13 @@ impl WildcardOptions { rename: self.rename, } } + + pub(crate) fn stats(&self) -> LogicalPlanStats { + self.replace + .iter() + .flat_map(|prsi| prsi.planned_expressions.iter()) + .fold(LogicalPlanStats::empty(), |s, e| s.merge(e.stats())) + } } impl Display for WildcardOptions { @@ -1115,8 +1220,12 @@ impl Expr { /// output schema. We can use this qualified name to reference the field. pub fn qualified_name(&self) -> (Option, String) { match self { - Expr::Column(Column { relation, name }) => (relation.clone(), name.clone()), - Expr::Alias(Alias { relation, name, .. }) => (relation.clone(), name.clone()), + Expr::Column(Column { relation, name }, _) => { + (relation.clone(), name.clone()) + } + Expr::Alias(Alias { relation, name, .. }, _) => { + (relation.clone(), name.clone()) + } _ => (None, self.schema_name().to_string()), } } @@ -1138,7 +1247,7 @@ impl Expr { Expr::Case { .. } => "Case", Expr::Cast { .. } => "Cast", Expr::Column(..) => "Column", - Expr::OuterReferenceColumn(_, _) => "Outer", + Expr::OuterReferenceColumn(_, _, _) => "Outer", Expr::Exists { .. } => "Exists", Expr::GroupingSet(..) => "GroupingSet", Expr::InList { .. } => "InList", @@ -1156,7 +1265,7 @@ impl Expr { Expr::Literal(..) => "Literal", Expr::Negative(..) => "Negative", Expr::Not(..) => "Not", - Expr::Placeholder(_) => "Placeholder", + Expr::Placeholder { .. } => "Placeholder", Expr::ScalarFunction(..) => "ScalarFunction", Expr::ScalarSubquery { .. } => "ScalarSubquery", Expr::ScalarVariable(..) => "ScalarVariable", @@ -1209,40 +1318,25 @@ impl Expr { /// Return `self LIKE other` pub fn like(self, other: Expr) -> Expr { - Expr::Like(Like::new( - false, - Box::new(self), - Box::new(other), - None, - false, - )) + let like = Like::new(false, Box::new(self), Box::new(other), None, false); + Expr::_like(like) } /// Return `self NOT LIKE other` pub fn not_like(self, other: Expr) -> Expr { - Expr::Like(Like::new( - true, - Box::new(self), - Box::new(other), - None, - false, - )) + let like = Like::new(true, Box::new(self), Box::new(other), None, false); + Expr::_like(like) } /// Return `self ILIKE other` pub fn ilike(self, other: Expr) -> Expr { - Expr::Like(Like::new( - false, - Box::new(self), - Box::new(other), - None, - true, - )) + let like = Like::new(false, Box::new(self), Box::new(other), None, true); + Expr::_like(like) } /// Return `self NOT ILIKE other` pub fn not_ilike(self, other: Expr) -> Expr { - Expr::Like(Like::new(true, Box::new(self), Box::new(other), None, true)) + Expr::_like(Like::new(true, Box::new(self), Box::new(other), None, true)) } /// Return the name to use for the specific Expr @@ -1263,7 +1357,9 @@ impl Expr { /// Return `self AS name` alias expression pub fn alias(self, name: impl Into) -> Expr { - Expr::Alias(Alias::new(self, None::<&str>, name.into())) + let alias = Alias::new(self, None::<&str>, name.into()); + let stats = alias.stats(); + Expr::Alias(alias, stats) } /// Return `self AS name` alias expression with a specific qualifier @@ -1272,7 +1368,9 @@ impl Expr { relation: Option>, name: impl Into, ) -> Expr { - Expr::Alias(Alias::new(self, relation, name.into())) + let alias = Alias::new(self, relation, name.into()); + let stats = alias.stats(); + Expr::Alias(alias, stats) } /// Remove an alias from an expression if one exists. @@ -1297,7 +1395,7 @@ impl Expr { /// ``` pub fn unalias(self) -> Expr { match self { - Expr::Alias(alias) => *alias.expr, + Expr::Alias(alias, _) => *alias.expr, _ => self, } } @@ -1328,7 +1426,9 @@ impl Expr { // f_down: skip subqueries. Check in f_down to avoid recursing into them let recursion = if matches!( expr, - Expr::Exists { .. } | Expr::ScalarSubquery(_) | Expr::InSubquery(_) + Expr::Exists { .. } + | Expr::ScalarSubquery { .. } + | Expr::InSubquery { .. } ) { // Subqueries could contain aliases so don't recurse into those TreeNodeRecursion::Jump @@ -1340,7 +1440,7 @@ impl Expr { |expr| { // f_up: unalias on up so we can remove nested aliases like // `(x as foo) as bar` - if let Expr::Alias(Alias { expr, .. }) = expr { + if let Expr::Alias(Alias { expr, .. }, _) = expr { Ok(Transformed::yes(*expr)) } else { Ok(Transformed::no(expr)) @@ -1354,17 +1454,17 @@ impl Expr { /// Return `self IN ` if `negated` is false, otherwise /// return `self NOT IN `.a pub fn in_list(self, list: Vec, negated: bool) -> Expr { - Expr::InList(InList::new(Box::new(self), list, negated)) + Expr::_in_list(InList::new(Box::new(self), list, negated)) } /// Return `IsNull(Box(self)) pub fn is_null(self) -> Expr { - Expr::IsNull(Box::new(self)) + Expr::_is_null(Box::new(self)) } /// Return `IsNotNull(Box(self)) pub fn is_not_null(self) -> Expr { - Expr::IsNotNull(Box::new(self)) + Expr::_is_not_null(Box::new(self)) } /// Create a sort configuration from an existing expression. @@ -1379,37 +1479,37 @@ impl Expr { /// Return `IsTrue(Box(self))` pub fn is_true(self) -> Expr { - Expr::IsTrue(Box::new(self)) + Expr::_is_true(Box::new(self)) } /// Return `IsNotTrue(Box(self))` pub fn is_not_true(self) -> Expr { - Expr::IsNotTrue(Box::new(self)) + Expr::_is_not_true(Box::new(self)) } /// Return `IsFalse(Box(self))` pub fn is_false(self) -> Expr { - Expr::IsFalse(Box::new(self)) + Expr::_is_false(Box::new(self)) } /// Return `IsNotFalse(Box(self))` pub fn is_not_false(self) -> Expr { - Expr::IsNotFalse(Box::new(self)) + Expr::_is_not_false(Box::new(self)) } /// Return `IsUnknown(Box(self))` pub fn is_unknown(self) -> Expr { - Expr::IsUnknown(Box::new(self)) + Expr::_is_unknown(Box::new(self)) } /// Return `IsNotUnknown(Box(self))` pub fn is_not_unknown(self) -> Expr { - Expr::IsNotUnknown(Box::new(self)) + Expr::_is_not_unknown(Box::new(self)) } /// return `self BETWEEN low AND high` pub fn between(self, low: Expr, high: Expr) -> Expr { - Expr::Between(Between::new( + Expr::_between(Between::new( Box::new(self), false, Box::new(low), @@ -1419,7 +1519,7 @@ impl Expr { /// Return `self NOT BETWEEN low AND high` pub fn not_between(self, low: Expr, high: Expr) -> Expr { - Expr::Between(Between::new( + Expr::_between(Between::new( Box::new(self), true, Box::new(low), @@ -1430,7 +1530,7 @@ impl Expr { #[deprecated(since = "39.0.0", note = "use try_as_col instead")] pub fn try_into_col(&self) -> Result { match self { - Expr::Column(it) => Ok(it.clone()), + Expr::Column(it, _) => Ok(it.clone()), _ => plan_err!("Could not coerce '{self}' into Column!"), } } @@ -1453,7 +1553,7 @@ impl Expr { /// assert_eq!(expr.try_as_col(), None); /// ``` pub fn try_as_col(&self) -> Option<&Column> { - if let Expr::Column(it) = self { + if let Expr::Column(it, _) = self { Some(it) } else { None @@ -1468,9 +1568,9 @@ impl Expr { /// or a `Cast` expression that wraps a `Column`. pub fn get_as_join_column(&self) -> Option<&Column> { match self { - Expr::Column(c) => Some(c), - Expr::Cast(Cast { expr, .. }) => match &**expr { - Expr::Column(c) => Some(c), + Expr::Column(c, _) => Some(c), + Expr::Cast(Cast { expr, .. }, _) => match &**expr { + Expr::Column(c, _) => Some(c), _ => None, }, _ => None, @@ -1512,7 +1612,7 @@ impl Expr { /// See [`Self::column_refs`] for details pub fn add_column_refs<'a>(&'a self, set: &mut HashSet<&'a Column>) { self.apply(|expr| { - if let Expr::Column(col) = expr { + if let Expr::Column(col, _) = expr { set.insert(col); } Ok(TreeNodeRecursion::Continue) @@ -1547,7 +1647,7 @@ impl Expr { /// See [`Self::column_refs_counts`] for details pub fn add_column_ref_counts<'a>(&'a self, map: &mut HashMap<&'a Column, usize>) { self.apply(|expr| { - if let Expr::Column(col) = expr { + if let Expr::Column(col, _) = expr { *map.entry(col).or_default() += 1; } Ok(TreeNodeRecursion::Continue) @@ -1557,7 +1657,7 @@ impl Expr { /// Returns true if there are any column references in this Expr pub fn any_column_refs(&self) -> bool { - self.exists(|expr| Ok(matches!(expr, Expr::Column(_)))) + self.exists(|expr| Ok(matches!(expr, Expr::Column(_, _)))) .expect("exists closure is infallible") } @@ -1573,7 +1673,7 @@ impl Expr { /// - `rand()` returns `true`, /// - `a + rand()` returns `false` pub fn is_volatile_node(&self) -> bool { - matches!(self, Expr::ScalarFunction(func) if func.func.signature().volatility == Volatility::Volatile) + matches!(self, Expr::ScalarFunction(func, _) if func.func.signature().volatility == Volatility::Volatile) } /// Returns true if the expression is volatile, i.e. whether it can return different @@ -1600,21 +1700,24 @@ impl Expr { let mut has_placeholder = false; self.transform(|mut expr| { // Default to assuming the arguments are the same type - if let Expr::BinaryExpr(BinaryExpr { left, op: _, right }) = &mut expr { + if let Expr::BinaryExpr(BinaryExpr { left, op: _, right }, _) = &mut expr { rewrite_placeholder(left.as_mut(), right.as_ref(), schema)?; rewrite_placeholder(right.as_mut(), left.as_ref(), schema)?; }; - if let Expr::Between(Between { - expr, - negated: _, - low, - high, - }) = &mut expr + if let Expr::Between( + Between { + expr, + negated: _, + low, + high, + }, + _, + ) = &mut expr { rewrite_placeholder(low.as_mut(), expr.as_ref(), schema)?; rewrite_placeholder(high.as_mut(), expr.as_ref(), schema)?; } - if let Expr::Placeholder(_) = &expr { + if let Expr::Placeholder(_, _) = &expr { has_placeholder = true; } Ok(Transformed::yes(expr)) @@ -1627,8 +1730,8 @@ impl Expr { /// and thus any side effects (like divide by zero) may not be encountered pub fn short_circuits(&self) -> bool { match self { - Expr::ScalarFunction(ScalarFunction { func, .. }) => func.short_circuits(), - Expr::BinaryExpr(BinaryExpr { op, .. }) => { + Expr::ScalarFunction(ScalarFunction { func, .. }, _) => func.short_circuits(), + Expr::BinaryExpr(BinaryExpr { op, .. }, _) => { matches!(op, Operator::And | Operator::Or) } Expr::Case { .. } => true, @@ -1654,11 +1757,11 @@ impl Expr { | Expr::IsUnknown(..) | Expr::Like(..) | Expr::ScalarSubquery(..) - | Expr::ScalarVariable(_, _) + | Expr::ScalarVariable(_, _, _) | Expr::SimilarTo(..) | Expr::Not(..) | Expr::Negative(..) - | Expr::OuterReferenceColumn(_, _) + | Expr::OuterReferenceColumn(_, _, _) | Expr::TryCast(..) | Expr::Unnest(..) | Expr::Wildcard { .. } @@ -1667,6 +1770,183 @@ impl Expr { | Expr::Placeholder(..) => false, } } + + pub fn wildcard(wildcard: Wildcard) -> Self { + let stats = wildcard.stats(); + Expr::Wildcard(wildcard, stats) + } + + pub fn binary_expr(binary_expr: BinaryExpr) -> Self { + let stats = LogicalPlanStats::new(enum_set!(LogicalPlanPattern::ExprBinaryExpr)) + .merge(binary_expr.stats()); + Expr::BinaryExpr(binary_expr, stats) + } + + pub fn similar_to(like: Like) -> Self { + let stats = like.stats(); + Expr::SimilarTo(like, stats) + } + + pub fn _like(like: Like) -> Self { + let stats = LogicalPlanStats::new(enum_set!(LogicalPlanPattern::ExprLike)) + .merge(like.stats()); + Expr::Like(like, stats) + } + + pub fn unnest(unnest: Unnest) -> Self { + let stats = unnest.stats(); + Expr::Unnest(unnest, stats) + } + + pub fn in_subquery(in_subquery: InSubquery) -> Self { + let stats = in_subquery.stats(); + Expr::InSubquery(in_subquery, stats) + } + + pub fn scalar_subquery(subquery: Subquery) -> Self { + let stats = subquery.stats(); + Expr::ScalarSubquery(subquery, stats) + } + + pub fn _not(expr: Box) -> Self { + let stats = LogicalPlanStats::new(enum_set!(LogicalPlanPattern::ExprNot)) + .merge(expr.stats()); + Expr::Not(expr, stats) + } + + pub fn _is_not_null(expr: Box) -> Self { + let stats = LogicalPlanStats::new(enum_set!(LogicalPlanPattern::ExprIsNotNull)) + .merge(expr.stats()); + Expr::IsNotNull(expr, stats) + } + + pub fn _is_null(expr: Box) -> Self { + let stats = LogicalPlanStats::new(enum_set!(LogicalPlanPattern::ExprIsNull)) + .merge(expr.stats()); + Expr::IsNull(expr, stats) + } + + pub fn _is_true(expr: Box) -> Self { + let stats = expr.stats(); + Expr::IsTrue(expr, stats) + } + + pub fn _is_false(expr: Box) -> Self { + let stats = expr.stats(); + Expr::IsFalse(expr, stats) + } + + pub fn _is_unknown(expr: Box) -> Self { + let stats = expr.stats(); + Expr::IsUnknown(expr, stats) + } + + pub fn _is_not_true(expr: Box) -> Self { + let stats = expr.stats(); + Expr::IsNotTrue(expr, stats) + } + + pub fn _is_not_false(expr: Box) -> Self { + let stats = expr.stats(); + Expr::IsNotFalse(expr, stats) + } + + pub fn _is_not_unknown(expr: Box) -> Self { + let stats = expr.stats(); + Expr::IsNotUnknown(expr, stats) + } + + pub fn negative(expr: Box) -> Self { + let stats = LogicalPlanStats::new(enum_set!(LogicalPlanPattern::ExprNegative)) + .merge(expr.stats()); + Expr::Negative(expr, stats) + } + + pub fn _between(between: Between) -> Self { + let stats = LogicalPlanStats::new(enum_set!(LogicalPlanPattern::ExprBetween)) + .merge(between.stats()); + Expr::Between(between, stats) + } + + pub fn case(case: Case) -> Self { + let stats = LogicalPlanStats::new(enum_set!(LogicalPlanPattern::ExprCase)) + .merge(case.stats()); + Expr::Case(case, stats) + } + + pub fn cast(cast: Cast) -> Self { + let stats = LogicalPlanStats::new(enum_set!(LogicalPlanPattern::ExprCast)) + .merge(cast.stats()); + Expr::Cast(cast, stats) + } + + pub fn try_cast(try_cast: TryCast) -> Self { + let stats = LogicalPlanStats::new(enum_set!(LogicalPlanPattern::ExprTryCast)) + .merge(try_cast.stats()); + Expr::TryCast(try_cast, stats) + } + + pub fn scalar_function(scalar_function: ScalarFunction) -> Self { + let stats = + LogicalPlanStats::new(enum_set!(LogicalPlanPattern::ExprScalarFunction)) + .merge(scalar_function.stats()); + Expr::ScalarFunction(scalar_function, stats) + } + + pub fn window_function(window_function: WindowFunction) -> Self { + let stats = + LogicalPlanStats::new(enum_set!(LogicalPlanPattern::ExprWindowFunction)) + .merge(window_function.stats()); + Expr::WindowFunction(window_function, stats) + } + + pub fn aggregate_function(aggregate_function: AggregateFunction) -> Self { + let stats = + LogicalPlanStats::new(enum_set!(LogicalPlanPattern::ExprAggregateFunction)) + .merge(aggregate_function.stats()); + Expr::AggregateFunction(aggregate_function, stats) + } + + pub fn grouping_set(grouping_set: GroupingSet) -> Self { + let stats = grouping_set.stats(); + Expr::GroupingSet(grouping_set, stats) + } + + pub fn _in_list(in_list: InList) -> Self { + let stats = LogicalPlanStats::new(enum_set!(LogicalPlanPattern::ExprInList)) + .merge(in_list.stats()); + Expr::InList(in_list, stats) + } + + pub fn exists(exists: Exists) -> Self { + let stats = exists.stats(); + Expr::Exists(exists, stats) + } + + pub fn literal(scalar_value: ScalarValue) -> Self { + let stats = LogicalPlanStats::new(enum_set!(LogicalPlanPattern::ExprLiteral)); + Expr::Literal(scalar_value, stats) + } + + pub fn column(column: Column) -> Self { + let stats = LogicalPlanStats::new(enum_set!(LogicalPlanPattern::ExprColumn)); + Expr::Column(column, stats) + } + + pub fn outer_reference_column(data_type: DataType, column: Column) -> Self { + let stats = LogicalPlanStats::empty(); + Expr::OuterReferenceColumn(data_type, column, stats) + } + + pub fn placeholder(placeholder: Placeholder) -> Self { + let stats = LogicalPlanStats::empty(); + Expr::Placeholder(placeholder, stats) + } + + pub fn scalar_variable(data_type: DataType, names: Vec) -> Self { + let stats = LogicalPlanStats::empty(); + Expr::ScalarVariable(data_type, names, stats) + } } impl HashNode for Expr { @@ -1676,157 +1956,181 @@ impl HashNode for Expr { fn hash_node(&self, state: &mut H) { mem::discriminant(self).hash(state); match self { - Expr::Alias(Alias { - expr: _expr, - relation, - name, - }) => { + Expr::Alias( + Alias { + expr: _, + relation, + name, + }, + _, + ) => { relation.hash(state); name.hash(state); } - Expr::Column(column) => { + Expr::Column(column, _) => { column.hash(state); } - Expr::ScalarVariable(data_type, name) => { + Expr::ScalarVariable(data_type, name, _) => { data_type.hash(state); name.hash(state); } - Expr::Literal(scalar_value) => { + Expr::Literal(scalar_value, _) => { scalar_value.hash(state); } - Expr::BinaryExpr(BinaryExpr { - left: _left, - op, - right: _right, - }) => { + Expr::BinaryExpr( + BinaryExpr { + left: _, + op, + right: _, + }, + _, + ) => { op.hash(state); } - Expr::Like(Like { - negated, - expr: _expr, - pattern: _pattern, - escape_char, - case_insensitive, - }) - | Expr::SimilarTo(Like { - negated, - expr: _expr, - pattern: _pattern, - escape_char, - case_insensitive, - }) => { + Expr::Like( + Like { + negated, + expr: _, + pattern: _, + escape_char, + case_insensitive, + }, + _, + ) + | Expr::SimilarTo( + Like { + negated, + expr: _, + pattern: _, + escape_char, + case_insensitive, + }, + _, + ) => { negated.hash(state); escape_char.hash(state); case_insensitive.hash(state); } - Expr::Not(_expr) - | Expr::IsNotNull(_expr) - | Expr::IsNull(_expr) - | Expr::IsTrue(_expr) - | Expr::IsFalse(_expr) - | Expr::IsUnknown(_expr) - | Expr::IsNotTrue(_expr) - | Expr::IsNotFalse(_expr) - | Expr::IsNotUnknown(_expr) - | Expr::Negative(_expr) => {} - Expr::Between(Between { - expr: _expr, - negated, - low: _low, - high: _high, - }) => { + Expr::Not(_, _) + | Expr::IsNotNull(_, _) + | Expr::IsNull(_, _) + | Expr::IsTrue(_, _) + | Expr::IsFalse(_, _) + | Expr::IsUnknown(_, _) + | Expr::IsNotTrue(_, _) + | Expr::IsNotFalse(_, _) + | Expr::IsNotUnknown(_, _) + | Expr::Negative(_, _) => {} + Expr::Between( + Between { + expr: _, + negated, + low: _, + high: _, + }, + _, + ) => { negated.hash(state); } - Expr::Case(Case { - expr: _expr, - when_then_expr: _when_then_expr, - else_expr: _else_expr, - }) => {} - Expr::Cast(Cast { - expr: _expr, - data_type, - }) - | Expr::TryCast(TryCast { - expr: _expr, - data_type, - }) => { + Expr::Case( + Case { + expr: _, + when_then_expr: _, + else_expr: _, + }, + _, + ) => {} + Expr::Cast(Cast { expr: _, data_type }, _) + | Expr::TryCast(TryCast { expr: _, data_type }, _) => { data_type.hash(state); } - Expr::ScalarFunction(ScalarFunction { func, args: _args }) => { + Expr::ScalarFunction(ScalarFunction { func, args: _ }, _) => { func.hash(state); } - Expr::AggregateFunction(AggregateFunction { - func, - args: _args, - distinct, - filter: _filter, - order_by: _order_by, - null_treatment, - }) => { + Expr::AggregateFunction( + AggregateFunction { + func, + args: _, + distinct, + filter: _, + order_by: _, + null_treatment, + }, + _, + ) => { func.hash(state); distinct.hash(state); null_treatment.hash(state); } - Expr::WindowFunction(WindowFunction { - fun, - args: _args, - partition_by: _partition_by, - order_by: _order_by, - window_frame, - null_treatment, - }) => { + Expr::WindowFunction( + WindowFunction { + fun, + args: _, + partition_by: _, + order_by: _, + window_frame, + null_treatment, + }, + _, + ) => { fun.hash(state); window_frame.hash(state); null_treatment.hash(state); } - Expr::InList(InList { - expr: _expr, - list: _list, - negated, - }) => { + Expr::InList( + InList { + expr: _, + list: _, + negated, + }, + _, + ) => { negated.hash(state); } - Expr::Exists(Exists { subquery, negated }) => { + Expr::Exists(Exists { subquery, negated }, _) => { subquery.hash(state); negated.hash(state); } - Expr::InSubquery(InSubquery { - expr: _expr, - subquery, - negated, - }) => { + Expr::InSubquery( + InSubquery { + expr: _, + subquery, + negated, + }, + _, + ) => { subquery.hash(state); negated.hash(state); } - Expr::ScalarSubquery(subquery) => { + Expr::ScalarSubquery(subquery, _) => { subquery.hash(state); } - Expr::Wildcard(wildcard) => { + Expr::Wildcard(wildcard, _) => { wildcard.hash(state); wildcard.hash(state); } - Expr::GroupingSet(grouping_set) => { + Expr::GroupingSet(grouping_set, _) => { mem::discriminant(grouping_set).hash(state); match grouping_set { - GroupingSet::Rollup(_exprs) | GroupingSet::Cube(_exprs) => {} - GroupingSet::GroupingSets(_exprs) => {} + GroupingSet::Rollup(_) | GroupingSet::Cube(_) => {} + GroupingSet::GroupingSets(_) => {} } } - Expr::Placeholder(place_holder) => { + Expr::Placeholder(place_holder, _) => { place_holder.hash(state); } - Expr::OuterReferenceColumn(data_type, column) => { + Expr::OuterReferenceColumn(data_type, column, _) => { data_type.hash(state); column.hash(state); } - Expr::Unnest(Unnest { expr: _expr }) => {} + Expr::Unnest(Unnest { expr: _ }, _) => {} }; } } // Modifies expr if it is a placeholder with datatype of right fn rewrite_placeholder(expr: &mut Expr, other: &Expr, schema: &DFSchema) -> Result<()> { - if let Expr::Placeholder(Placeholder { id: _, data_type }) = expr { + if let Expr::Placeholder(Placeholder { id: _, data_type }, _) = expr { if data_type.is_none() { let other_dt = other.get_type(schema); match other_dt { @@ -1860,21 +2164,24 @@ impl<'a> Display for SchemaDisplay<'a> { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { match self.0 { // The same as Display - Expr::Column(_) - | Expr::Literal(_) + Expr::Column(_, _) + | Expr::Literal(_, _) | Expr::ScalarVariable(..) | Expr::OuterReferenceColumn(..) - | Expr::Placeholder(_) + | Expr::Placeholder(_, _) | Expr::Wildcard { .. } => write!(f, "{}", self.0), - Expr::AggregateFunction(AggregateFunction { - func, - args, - distinct, - filter, - order_by, - null_treatment, - }) => { + Expr::AggregateFunction( + AggregateFunction { + func, + args, + distinct, + filter, + order_by, + null_treatment, + }, + _, + ) => { write!( f, "{}({}{})", @@ -1898,13 +2205,16 @@ impl<'a> Display for SchemaDisplay<'a> { Ok(()) } // Expr is not shown since it is aliased - Expr::Alias(Alias { name, .. }) => write!(f, "{name}"), - Expr::Between(Between { - expr, - negated, - low, - high, - }) => { + Expr::Alias(Alias { name, .. }, _) => write!(f, "{name}"), + Expr::Between( + Between { + expr, + negated, + low, + high, + }, + _, + ) => { if *negated { write!( f, @@ -1923,14 +2233,17 @@ impl<'a> Display for SchemaDisplay<'a> { ) } } - Expr::BinaryExpr(BinaryExpr { left, op, right }) => { + Expr::BinaryExpr(BinaryExpr { left, op, right }, _) => { write!(f, "{} {op} {}", SchemaDisplay(left), SchemaDisplay(right),) } - Expr::Case(Case { - expr, - when_then_expr, - else_expr, - }) => { + Expr::Case( + Case { + expr, + when_then_expr, + else_expr, + }, + _, + ) => { write!(f, "CASE ")?; if let Some(e) = expr { @@ -1953,14 +2266,17 @@ impl<'a> Display for SchemaDisplay<'a> { write!(f, "END") } // Cast expr is not shown to be consistant with Postgres and Spark - Expr::Cast(Cast { expr, .. }) | Expr::TryCast(TryCast { expr, .. }) => { + Expr::Cast(Cast { expr, .. }, _) | Expr::TryCast(TryCast { expr, .. }, _) => { write!(f, "{}", SchemaDisplay(expr)) } - Expr::InList(InList { - expr, - list, - negated, - }) => { + Expr::InList( + InList { + expr, + list, + negated, + }, + _, + ) => { let inlist_name = schema_name_from_exprs(list)?; if *negated { @@ -1969,50 +2285,53 @@ impl<'a> Display for SchemaDisplay<'a> { write!(f, "{} IN {}", SchemaDisplay(expr), inlist_name) } } - Expr::Exists(Exists { negated: true, .. }) => write!(f, "NOT EXISTS"), - Expr::Exists(Exists { negated: false, .. }) => write!(f, "EXISTS"), - Expr::GroupingSet(GroupingSet::Cube(exprs)) => { + Expr::Exists(Exists { negated: true, .. }, _) => write!(f, "NOT EXISTS"), + Expr::Exists(Exists { negated: false, .. }, _) => write!(f, "EXISTS"), + Expr::GroupingSet(GroupingSet::Cube(exprs), _) => { write!(f, "ROLLUP ({})", schema_name_from_exprs(exprs)?) } - Expr::GroupingSet(GroupingSet::GroupingSets(lists_of_exprs)) => { + Expr::GroupingSet(GroupingSet::GroupingSets(lists_of_exprs), _) => { write!(f, "GROUPING SETS (")?; for exprs in lists_of_exprs.iter() { write!(f, "({})", schema_name_from_exprs(exprs)?)?; } write!(f, ")") } - Expr::GroupingSet(GroupingSet::Rollup(exprs)) => { + Expr::GroupingSet(GroupingSet::Rollup(exprs), _) => { write!(f, "ROLLUP ({})", schema_name_from_exprs(exprs)?) } - Expr::IsNull(expr) => write!(f, "{} IS NULL", SchemaDisplay(expr)), - Expr::IsNotNull(expr) => { + Expr::IsNull(expr, _) => write!(f, "{} IS NULL", SchemaDisplay(expr)), + Expr::IsNotNull(expr, _) => { write!(f, "{} IS NOT NULL", SchemaDisplay(expr)) } - Expr::IsUnknown(expr) => { + Expr::IsUnknown(expr, _) => { write!(f, "{} IS UNKNOWN", SchemaDisplay(expr)) } - Expr::IsNotUnknown(expr) => { + Expr::IsNotUnknown(expr, _) => { write!(f, "{} IS NOT UNKNOWN", SchemaDisplay(expr)) } - Expr::InSubquery(InSubquery { negated: true, .. }) => { + Expr::InSubquery(InSubquery { negated: true, .. }, _) => { write!(f, "NOT IN") } - Expr::InSubquery(InSubquery { negated: false, .. }) => write!(f, "IN"), - Expr::IsTrue(expr) => write!(f, "{} IS TRUE", SchemaDisplay(expr)), - Expr::IsFalse(expr) => write!(f, "{} IS FALSE", SchemaDisplay(expr)), - Expr::IsNotTrue(expr) => { + Expr::InSubquery(InSubquery { negated: false, .. }, _) => write!(f, "IN"), + Expr::IsTrue(expr, _) => write!(f, "{} IS TRUE", SchemaDisplay(expr)), + Expr::IsFalse(expr, _) => write!(f, "{} IS FALSE", SchemaDisplay(expr)), + Expr::IsNotTrue(expr, _) => { write!(f, "{} IS NOT TRUE", SchemaDisplay(expr)) } - Expr::IsNotFalse(expr) => { + Expr::IsNotFalse(expr, _) => { write!(f, "{} IS NOT FALSE", SchemaDisplay(expr)) } - Expr::Like(Like { - negated, - expr, - pattern, - escape_char, - case_insensitive, - }) => { + Expr::Like( + Like { + negated, + expr, + pattern, + escape_char, + case_insensitive, + }, + _, + ) => { write!( f, "{} {}{} {}", @@ -2028,12 +2347,12 @@ impl<'a> Display for SchemaDisplay<'a> { Ok(()) } - Expr::Negative(expr) => write!(f, "(- {})", SchemaDisplay(expr)), - Expr::Not(expr) => write!(f, "NOT {}", SchemaDisplay(expr)), - Expr::Unnest(Unnest { expr }) => { + Expr::Negative(expr, _) => write!(f, "(- {})", SchemaDisplay(expr)), + Expr::Not(expr, _) => write!(f, "NOT {}", SchemaDisplay(expr)), + Expr::Unnest(Unnest { expr }, _) => { write!(f, "UNNEST({})", SchemaDisplay(expr)) } - Expr::ScalarFunction(ScalarFunction { func, args }) => { + Expr::ScalarFunction(ScalarFunction { func, args }, _) => { match func.schema_name(args) { Ok(name) => { write!(f, "{name}") @@ -2043,16 +2362,19 @@ impl<'a> Display for SchemaDisplay<'a> { } } } - Expr::ScalarSubquery(Subquery { subquery, .. }) => { + Expr::ScalarSubquery(Subquery { subquery, .. }, _) => { write!(f, "{}", subquery.schema().field(0).name()) } - Expr::SimilarTo(Like { - negated, - expr, - pattern, - escape_char, - .. - }) => { + Expr::SimilarTo( + Like { + negated, + expr, + pattern, + escape_char, + .. + }, + _, + ) => { write!( f, "{} {} {}", @@ -2070,14 +2392,17 @@ impl<'a> Display for SchemaDisplay<'a> { Ok(()) } - Expr::WindowFunction(WindowFunction { - fun, - args, - partition_by, - order_by, - window_frame, - null_treatment, - }) => { + Expr::WindowFunction( + WindowFunction { + fun, + args, + partition_by, + order_by, + window_frame, + null_treatment, + }, + _, + ) => { write!( f, "{}({})", @@ -2158,12 +2483,12 @@ pub fn schema_name_from_sorts(sorts: &[Sort]) -> Result { impl Display for Expr { fn fmt(&self, f: &mut Formatter) -> fmt::Result { match self { - Expr::Alias(Alias { expr, name, .. }) => write!(f, "{expr} AS {name}"), - Expr::Column(c) => write!(f, "{c}"), - Expr::OuterReferenceColumn(_, c) => write!(f, "outer_ref({c})"), - Expr::ScalarVariable(_, var_names) => write!(f, "{}", var_names.join(".")), - Expr::Literal(v) => write!(f, "{v:?}"), - Expr::Case(case) => { + Expr::Alias(Alias { expr, name, .. }, _) => write!(f, "{expr} AS {name}"), + Expr::Column(c, _) => write!(f, "{c}"), + Expr::OuterReferenceColumn(_, c, _) => write!(f, "outer_ref({c})"), + Expr::ScalarVariable(_, var_names, _) => write!(f, "{}", var_names.join(".")), + Expr::Literal(v, _) => write!(f, "{v:?}"), + Expr::Case(case, _) => { write!(f, "CASE ")?; if let Some(e) = &case.expr { write!(f, "{e} ")?; @@ -2176,57 +2501,72 @@ impl Display for Expr { } write!(f, "END") } - Expr::Cast(Cast { expr, data_type }) => { + Expr::Cast(Cast { expr, data_type }, _) => { write!(f, "CAST({expr} AS {data_type:?})") } - Expr::TryCast(TryCast { expr, data_type }) => { + Expr::TryCast(TryCast { expr, data_type }, _) => { write!(f, "TRY_CAST({expr} AS {data_type:?})") } - Expr::Not(expr) => write!(f, "NOT {expr}"), - Expr::Negative(expr) => write!(f, "(- {expr})"), - Expr::IsNull(expr) => write!(f, "{expr} IS NULL"), - Expr::IsNotNull(expr) => write!(f, "{expr} IS NOT NULL"), - Expr::IsTrue(expr) => write!(f, "{expr} IS TRUE"), - Expr::IsFalse(expr) => write!(f, "{expr} IS FALSE"), - Expr::IsUnknown(expr) => write!(f, "{expr} IS UNKNOWN"), - Expr::IsNotTrue(expr) => write!(f, "{expr} IS NOT TRUE"), - Expr::IsNotFalse(expr) => write!(f, "{expr} IS NOT FALSE"), - Expr::IsNotUnknown(expr) => write!(f, "{expr} IS NOT UNKNOWN"), - Expr::Exists(Exists { - subquery, - negated: true, - }) => write!(f, "NOT EXISTS ({subquery:?})"), - Expr::Exists(Exists { - subquery, - negated: false, - }) => write!(f, "EXISTS ({subquery:?})"), - Expr::InSubquery(InSubquery { - expr, - subquery, - negated: true, - }) => write!(f, "{expr} NOT IN ({subquery:?})"), - Expr::InSubquery(InSubquery { - expr, - subquery, - negated: false, - }) => write!(f, "{expr} IN ({subquery:?})"), - Expr::ScalarSubquery(subquery) => write!(f, "({subquery:?})"), - Expr::BinaryExpr(expr) => write!(f, "{expr}"), - Expr::ScalarFunction(fun) => { + Expr::Not(expr, _) => write!(f, "NOT {expr}"), + Expr::Negative(expr, _) => write!(f, "(- {expr})"), + Expr::IsNull(expr, _) => write!(f, "{expr} IS NULL"), + Expr::IsNotNull(expr, _) => write!(f, "{expr} IS NOT NULL"), + Expr::IsTrue(expr, _) => write!(f, "{expr} IS TRUE"), + Expr::IsFalse(expr, _) => write!(f, "{expr} IS FALSE"), + Expr::IsUnknown(expr, _) => write!(f, "{expr} IS UNKNOWN"), + Expr::IsNotTrue(expr, _) => write!(f, "{expr} IS NOT TRUE"), + Expr::IsNotFalse(expr, _) => write!(f, "{expr} IS NOT FALSE"), + Expr::IsNotUnknown(expr, _) => write!(f, "{expr} IS NOT UNKNOWN"), + Expr::Exists( + Exists { + subquery, + negated: true, + }, + _, + ) => write!(f, "NOT EXISTS ({subquery:?})"), + Expr::Exists( + Exists { + subquery, + negated: false, + }, + _, + ) => write!(f, "EXISTS ({subquery:?})"), + Expr::InSubquery( + InSubquery { + expr, + subquery, + negated: true, + }, + _, + ) => write!(f, "{expr} NOT IN ({subquery:?})"), + Expr::InSubquery( + InSubquery { + expr, + subquery, + negated: false, + }, + _, + ) => write!(f, "{expr} IN ({subquery:?})"), + Expr::ScalarSubquery(subquery, _) => write!(f, "({subquery:?})"), + Expr::BinaryExpr(expr, _) => write!(f, "{expr}"), + Expr::ScalarFunction(fun, _) => { fmt_function(f, fun.name(), false, &fun.args, true) } // TODO: use udf's display_name, need to fix the seperator issue, // Expr::ScalarFunction(ScalarFunction { func, args }) => { // write!(f, "{}", func.display_name(args).unwrap()) // } - Expr::WindowFunction(WindowFunction { - fun, - args, - partition_by, - order_by, - window_frame, - null_treatment, - }) => { + Expr::WindowFunction( + WindowFunction { + fun, + args, + partition_by, + order_by, + window_frame, + null_treatment, + }, + _, + ) => { fmt_function(f, &fun.to_string(), false, args, true)?; if let Some(nt) = null_treatment { @@ -2246,15 +2586,18 @@ impl Display for Expr { )?; Ok(()) } - Expr::AggregateFunction(AggregateFunction { - func, - distinct, - ref args, - filter, - order_by, - null_treatment, - .. - }) => { + Expr::AggregateFunction( + AggregateFunction { + func, + distinct, + ref args, + filter, + order_by, + null_treatment, + .. + }, + _, + ) => { fmt_function(f, func.name(), *distinct, args, true)?; if let Some(nt) = null_treatment { write!(f, " {}", nt)?; @@ -2267,25 +2610,31 @@ impl Display for Expr { } Ok(()) } - Expr::Between(Between { - expr, - negated, - low, - high, - }) => { + Expr::Between( + Between { + expr, + negated, + low, + high, + }, + _, + ) => { if *negated { write!(f, "{expr} NOT BETWEEN {low} AND {high}") } else { write!(f, "{expr} BETWEEN {low} AND {high}") } } - Expr::Like(Like { - negated, - expr, - pattern, - escape_char, - case_insensitive, - }) => { + Expr::Like( + Like { + negated, + expr, + pattern, + escape_char, + case_insensitive, + }, + _, + ) => { write!(f, "{expr}")?; let op_name = if *case_insensitive { "ILIKE" } else { "LIKE" }; if *negated { @@ -2297,13 +2646,16 @@ impl Display for Expr { write!(f, " {op_name} {pattern}") } } - Expr::SimilarTo(Like { - negated, - expr, - pattern, - escape_char, - case_insensitive: _, - }) => { + Expr::SimilarTo( + Like { + negated, + expr, + pattern, + escape_char, + case_insensitive: _, + }, + _, + ) => { write!(f, "{expr}")?; if *negated { write!(f, " NOT")?; @@ -2314,22 +2666,25 @@ impl Display for Expr { write!(f, " SIMILAR TO {pattern}") } } - Expr::InList(InList { - expr, - list, - negated, - }) => { + Expr::InList( + InList { + expr, + list, + negated, + }, + _, + ) => { if *negated { write!(f, "{expr} NOT IN ([{}])", expr_vec_fmt!(list)) } else { write!(f, "{expr} IN ([{}])", expr_vec_fmt!(list)) } } - Expr::Wildcard(Wildcard { qualifier, options }) => match qualifier { + Expr::Wildcard(Wildcard { qualifier, options }, _) => match qualifier { Some(qualifier) => write!(f, "{qualifier}.*{options}"), None => write!(f, "*{options}"), }, - Expr::GroupingSet(grouping_sets) => match grouping_sets { + Expr::GroupingSet(grouping_sets, _) => match grouping_sets { GroupingSet::Rollup(exprs) => { // ROLLUP (c0, c1, c2) write!(f, "ROLLUP ({})", expr_vec_fmt!(exprs)) @@ -2351,8 +2706,8 @@ impl Display for Expr { ) } }, - Expr::Placeholder(Placeholder { id, .. }) => write!(f, "{id}"), - Expr::Unnest(Unnest { expr }) => { + Expr::Placeholder(Placeholder { id, .. }, _) => write!(f, "{id}"), + Expr::Unnest(Unnest { expr }, _) => { write!(f, "UNNEST({expr})") } } @@ -2381,7 +2736,7 @@ fn fmt_function( /// The name of the column (field) that this `Expr` will produce in the physical plan. /// The difference from [Expr::schema_name] is that top-level columns are unqualified. pub fn physical_name(expr: &Expr) -> Result { - if let Expr::Column(col) = expr { + if let Expr::Column(col, _) = expr { Ok(col.name.clone()) } else { Ok(expr.schema_name().to_string()) @@ -2415,8 +2770,8 @@ mod test { #[test] #[allow(deprecated)] fn format_cast() -> Result<()> { - let expr = Expr::Cast(Cast { - expr: Box::new(Expr::Literal(ScalarValue::Float32(Some(1.23)))), + let expr = Expr::cast(Cast { + expr: Box::new(Expr::literal(ScalarValue::Float32(Some(1.23)))), data_type: DataType::Utf8, }); let expected_canonical = "CAST(Float32(1.23) AS Utf8)"; @@ -2445,7 +2800,7 @@ mod test { fn test_collect_expr() -> Result<()> { // single column { - let expr = &Expr::Cast(Cast::new(Box::new(col("a")), DataType::Float64)); + let expr = &Expr::cast(Cast::new(Box::new(col("a")), DataType::Float64)); let columns = expr.column_refs(); assert_eq!(1, columns.len()); assert!(columns.contains(&Column::from_name("a"))); diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index c38ffb888f024..9003c11cf3493 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -62,13 +62,13 @@ use std::sync::Arc; /// assert_ne!(c1, c3); /// ``` pub fn col(ident: impl Into) -> Expr { - Expr::Column(ident.into()) + Expr::column(ident.into()) } /// Create an out reference column which hold a reference that has been resolved to a field /// outside of the current plan. pub fn out_ref_col(dt: DataType, ident: impl Into) -> Expr { - Expr::OuterReferenceColumn(dt, ident.into()) + Expr::outer_reference_column(dt, ident.into()) } /// Create an unqualified column expression from the provided name, without normalizing @@ -90,7 +90,7 @@ pub fn out_ref_col(dt: DataType, ident: impl Into) -> Expr { /// assert_ne!(c4, c5); /// ``` pub fn ident(name: impl Into) -> Expr { - Expr::Column(Column::from_name(name)) + Expr::column(Column::from_name(name)) } /// Create placeholder value that will be filled in (such as `$1`) @@ -105,7 +105,7 @@ pub fn ident(name: impl Into) -> Expr { /// assert_eq!(p.to_string(), "$0") /// ``` pub fn placeholder(id: impl Into) -> Expr { - Expr::Placeholder(Placeholder { + Expr::placeholder(Placeholder { id: id.into(), data_type: None, }) @@ -121,7 +121,7 @@ pub fn placeholder(id: impl Into) -> Expr { /// assert_eq!(p.to_string(), "*") /// ``` pub fn wildcard() -> Expr { - Expr::Wildcard(Wildcard { + Expr::wildcard(Wildcard { qualifier: None, options: WildcardOptions::default(), }) @@ -129,7 +129,7 @@ pub fn wildcard() -> Expr { /// Create an '*' [`Expr::Wildcard`] expression with the wildcard options pub fn wildcard_with_options(options: WildcardOptions) -> Expr { - Expr::Wildcard(Wildcard { + Expr::wildcard(Wildcard { qualifier: None, options, }) @@ -146,7 +146,7 @@ pub fn wildcard_with_options(options: WildcardOptions) -> Expr { /// assert_eq!(p.to_string(), "t.*") /// ``` pub fn qualified_wildcard(qualifier: impl Into) -> Expr { - Expr::Wildcard(Wildcard { + Expr::wildcard(Wildcard { qualifier: Some(qualifier.into()), options: WildcardOptions::default(), }) @@ -157,7 +157,7 @@ pub fn qualified_wildcard_with_options( qualifier: impl Into, options: WildcardOptions, ) -> Expr { - Expr::Wildcard(Wildcard { + Expr::wildcard(Wildcard { qualifier: Some(qualifier.into()), options, }) @@ -165,12 +165,12 @@ pub fn qualified_wildcard_with_options( /// Return a new expression `left right` pub fn binary_expr(left: Expr, op: Operator, right: Expr) -> Expr { - Expr::BinaryExpr(BinaryExpr::new(Box::new(left), op, Box::new(right))) + Expr::binary_expr(BinaryExpr::new(Box::new(left), op, Box::new(right))) } /// Return a new expression with a logical AND pub fn and(left: Expr, right: Expr) -> Expr { - Expr::BinaryExpr(BinaryExpr::new( + Expr::binary_expr(BinaryExpr::new( Box::new(left), Operator::And, Box::new(right), @@ -179,7 +179,7 @@ pub fn and(left: Expr, right: Expr) -> Expr { /// Return a new expression with a logical OR pub fn or(left: Expr, right: Expr) -> Expr { - Expr::BinaryExpr(BinaryExpr::new( + Expr::binary_expr(BinaryExpr::new( Box::new(left), Operator::Or, Box::new(right), @@ -193,7 +193,7 @@ pub fn not(expr: Expr) -> Expr { /// Return a new expression with bitwise AND pub fn bitwise_and(left: Expr, right: Expr) -> Expr { - Expr::BinaryExpr(BinaryExpr::new( + Expr::binary_expr(BinaryExpr::new( Box::new(left), Operator::BitwiseAnd, Box::new(right), @@ -202,7 +202,7 @@ pub fn bitwise_and(left: Expr, right: Expr) -> Expr { /// Return a new expression with bitwise OR pub fn bitwise_or(left: Expr, right: Expr) -> Expr { - Expr::BinaryExpr(BinaryExpr::new( + Expr::binary_expr(BinaryExpr::new( Box::new(left), Operator::BitwiseOr, Box::new(right), @@ -211,7 +211,7 @@ pub fn bitwise_or(left: Expr, right: Expr) -> Expr { /// Return a new expression with bitwise XOR pub fn bitwise_xor(left: Expr, right: Expr) -> Expr { - Expr::BinaryExpr(BinaryExpr::new( + Expr::binary_expr(BinaryExpr::new( Box::new(left), Operator::BitwiseXor, Box::new(right), @@ -220,7 +220,7 @@ pub fn bitwise_xor(left: Expr, right: Expr) -> Expr { /// Return a new expression with bitwise SHIFT RIGHT pub fn bitwise_shift_right(left: Expr, right: Expr) -> Expr { - Expr::BinaryExpr(BinaryExpr::new( + Expr::binary_expr(BinaryExpr::new( Box::new(left), Operator::BitwiseShiftRight, Box::new(right), @@ -229,7 +229,7 @@ pub fn bitwise_shift_right(left: Expr, right: Expr) -> Expr { /// Return a new expression with bitwise SHIFT LEFT pub fn bitwise_shift_left(left: Expr, right: Expr) -> Expr { - Expr::BinaryExpr(BinaryExpr::new( + Expr::binary_expr(BinaryExpr::new( Box::new(left), Operator::BitwiseShiftLeft, Box::new(right), @@ -238,13 +238,13 @@ pub fn bitwise_shift_left(left: Expr, right: Expr) -> Expr { /// Create an in_list expression pub fn in_list(expr: Expr, list: Vec, negated: bool) -> Expr { - Expr::InList(InList::new(Box::new(expr), list, negated)) + Expr::_in_list(InList::new(Box::new(expr), list, negated)) } /// Create an EXISTS subquery expression pub fn exists(subquery: Arc) -> Expr { let outer_ref_columns = subquery.all_out_ref_exprs(); - Expr::Exists(Exists { + Expr::exists(Exists { subquery: Subquery { subquery, outer_ref_columns, @@ -256,7 +256,7 @@ pub fn exists(subquery: Arc) -> Expr { /// Create a NOT EXISTS subquery expression pub fn not_exists(subquery: Arc) -> Expr { let outer_ref_columns = subquery.all_out_ref_exprs(); - Expr::Exists(Exists { + Expr::exists(Exists { subquery: Subquery { subquery, outer_ref_columns, @@ -268,7 +268,7 @@ pub fn not_exists(subquery: Arc) -> Expr { /// Create an IN subquery expression pub fn in_subquery(expr: Expr, subquery: Arc) -> Expr { let outer_ref_columns = subquery.all_out_ref_exprs(); - Expr::InSubquery(InSubquery::new( + Expr::in_subquery(InSubquery::new( Box::new(expr), Subquery { subquery, @@ -281,7 +281,7 @@ pub fn in_subquery(expr: Expr, subquery: Arc) -> Expr { /// Create a NOT IN subquery expression pub fn not_in_subquery(expr: Expr, subquery: Arc) -> Expr { let outer_ref_columns = subquery.all_out_ref_exprs(); - Expr::InSubquery(InSubquery::new( + Expr::in_subquery(InSubquery::new( Box::new(expr), Subquery { subquery, @@ -294,7 +294,7 @@ pub fn not_in_subquery(expr: Expr, subquery: Arc) -> Expr { /// Create a scalar subquery expression pub fn scalar_subquery(subquery: Arc) -> Expr { let outer_ref_columns = subquery.all_out_ref_exprs(); - Expr::ScalarSubquery(Subquery { + Expr::scalar_subquery(Subquery { subquery, outer_ref_columns, }) @@ -302,62 +302,62 @@ pub fn scalar_subquery(subquery: Arc) -> Expr { /// Create a grouping set pub fn grouping_set(exprs: Vec>) -> Expr { - Expr::GroupingSet(GroupingSet::GroupingSets(exprs)) + Expr::grouping_set(GroupingSet::GroupingSets(exprs)) } /// Create a grouping set for all combination of `exprs` pub fn cube(exprs: Vec) -> Expr { - Expr::GroupingSet(GroupingSet::Cube(exprs)) + Expr::grouping_set(GroupingSet::Cube(exprs)) } /// Create a grouping set for rollup pub fn rollup(exprs: Vec) -> Expr { - Expr::GroupingSet(GroupingSet::Rollup(exprs)) + Expr::grouping_set(GroupingSet::Rollup(exprs)) } /// Create a cast expression pub fn cast(expr: Expr, data_type: DataType) -> Expr { - Expr::Cast(Cast::new(Box::new(expr), data_type)) + Expr::cast(Cast::new(Box::new(expr), data_type)) } /// Create a try cast expression pub fn try_cast(expr: Expr, data_type: DataType) -> Expr { - Expr::TryCast(TryCast::new(Box::new(expr), data_type)) + Expr::try_cast(TryCast::new(Box::new(expr), data_type)) } /// Create is null expression pub fn is_null(expr: Expr) -> Expr { - Expr::IsNull(Box::new(expr)) + Expr::_is_null(Box::new(expr)) } /// Create is true expression pub fn is_true(expr: Expr) -> Expr { - Expr::IsTrue(Box::new(expr)) + Expr::_is_true(Box::new(expr)) } /// Create is not true expression pub fn is_not_true(expr: Expr) -> Expr { - Expr::IsNotTrue(Box::new(expr)) + Expr::_is_not_true(Box::new(expr)) } /// Create is false expression pub fn is_false(expr: Expr) -> Expr { - Expr::IsFalse(Box::new(expr)) + Expr::_is_false(Box::new(expr)) } /// Create is not false expression pub fn is_not_false(expr: Expr) -> Expr { - Expr::IsNotFalse(Box::new(expr)) + Expr::_is_not_false(Box::new(expr)) } /// Create is unknown expression pub fn is_unknown(expr: Expr) -> Expr { - Expr::IsUnknown(Box::new(expr)) + Expr::_is_unknown(Box::new(expr)) } /// Create is not unknown expression pub fn is_not_unknown(expr: Expr) -> Expr { - Expr::IsNotUnknown(Box::new(expr)) + Expr::_is_not_unknown(Box::new(expr)) } /// Create a CASE WHEN statement with literal WHEN expressions for comparison to the base expression. @@ -372,7 +372,7 @@ pub fn when(when: Expr, then: Expr) -> CaseBuilder { /// Create a Unnest expression pub fn unnest(expr: Expr) -> Expr { - Expr::Unnest(Unnest { + Expr::unnest(Unnest { expr: Box::new(expr), }) } @@ -677,17 +677,17 @@ impl WindowUDFImpl for SimpleWindowUDF { pub fn interval_year_month_lit(value: &str) -> Expr { let interval = parse_interval_year_month(value).ok(); - Expr::Literal(ScalarValue::IntervalYearMonth(interval)) + Expr::literal(ScalarValue::IntervalYearMonth(interval)) } pub fn interval_datetime_lit(value: &str) -> Expr { let interval = parse_interval_day_time(value).ok(); - Expr::Literal(ScalarValue::IntervalDayTime(interval)) + Expr::literal(ScalarValue::IntervalDayTime(interval)) } pub fn interval_month_day_nano_lit(value: &str) -> Expr { let interval = parse_interval_month_day_nano(value).ok(); - Expr::Literal(ScalarValue::IntervalMonthDayNano(interval)) + Expr::literal(ScalarValue::IntervalMonthDayNano(interval)) } /// Extensions for configuring [`Expr::AggregateFunction`] or [`Expr::WindowFunction`] @@ -812,7 +812,7 @@ impl ExprFuncBuilder { udaf.filter = filter.map(Box::new); udaf.distinct = distinct; udaf.null_treatment = null_treatment; - Expr::AggregateFunction(udaf) + Expr::aggregate_function(udaf) } ExprFuncKind::Window(mut udwf) => { let has_order_by = order_by.as_ref().map(|o| !o.is_empty()); @@ -821,7 +821,7 @@ impl ExprFuncBuilder { udwf.window_frame = window_frame.unwrap_or(WindowFrame::new(has_order_by)); udwf.null_treatment = null_treatment; - Expr::WindowFunction(udwf) + Expr::window_function(udwf) } }; @@ -871,10 +871,10 @@ impl ExprFunctionExt for ExprFuncBuilder { impl ExprFunctionExt for Expr { fn order_by(self, order_by: Vec) -> ExprFuncBuilder { let mut builder = match self { - Expr::AggregateFunction(udaf) => { + Expr::AggregateFunction(udaf, _) => { ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf))) } - Expr::WindowFunction(udwf) => { + Expr::WindowFunction(udwf, _) => { ExprFuncBuilder::new(Some(ExprFuncKind::Window(udwf))) } _ => ExprFuncBuilder::new(None), @@ -886,7 +886,7 @@ impl ExprFunctionExt for Expr { } fn filter(self, filter: Expr) -> ExprFuncBuilder { match self { - Expr::AggregateFunction(udaf) => { + Expr::AggregateFunction(udaf, _) => { let mut builder = ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf))); builder.filter = Some(filter); @@ -897,7 +897,7 @@ impl ExprFunctionExt for Expr { } fn distinct(self) -> ExprFuncBuilder { match self { - Expr::AggregateFunction(udaf) => { + Expr::AggregateFunction(udaf, _) => { let mut builder = ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf))); builder.distinct = true; @@ -911,10 +911,10 @@ impl ExprFunctionExt for Expr { null_treatment: impl Into>, ) -> ExprFuncBuilder { let mut builder = match self { - Expr::AggregateFunction(udaf) => { + Expr::AggregateFunction(udaf, _) => { ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf))) } - Expr::WindowFunction(udwf) => { + Expr::WindowFunction(udwf, _) => { ExprFuncBuilder::new(Some(ExprFuncKind::Window(udwf))) } _ => ExprFuncBuilder::new(None), @@ -927,7 +927,7 @@ impl ExprFunctionExt for Expr { fn partition_by(self, partition_by: Vec) -> ExprFuncBuilder { match self { - Expr::WindowFunction(udwf) => { + Expr::WindowFunction(udwf, _) => { let mut builder = ExprFuncBuilder::new(Some(ExprFuncKind::Window(udwf))); builder.partition_by = Some(partition_by); builder @@ -938,7 +938,7 @@ impl ExprFunctionExt for Expr { fn window_frame(self, window_frame: WindowFrame) -> ExprFuncBuilder { match self { - Expr::WindowFunction(udwf) => { + Expr::WindowFunction(udwf, _) => { let mut builder = ExprFuncBuilder::new(Some(ExprFuncKind::Window(udwf))); builder.window_frame = Some(window_frame); builder diff --git a/datafusion/expr/src/expr_rewriter/mod.rs b/datafusion/expr/src/expr_rewriter/mod.rs index b944428977c4c..051a279688901 100644 --- a/datafusion/expr/src/expr_rewriter/mod.rs +++ b/datafusion/expr/src/expr_rewriter/mod.rs @@ -64,9 +64,9 @@ pub trait FunctionRewrite: Debug { pub fn normalize_col(expr: Expr, plan: &LogicalPlan) -> Result { expr.transform(|expr| { Ok({ - if let Expr::Column(c) = expr { + if let Expr::Column(c, _) = expr { let col = LogicalPlanBuilder::normalize(plan, c)?; - Transformed::yes(Expr::Column(col)) + Transformed::yes(Expr::column(col)) } else { Transformed::no(expr) } @@ -82,21 +82,21 @@ pub fn normalize_col_with_schemas_and_ambiguity_check( using_columns: &[HashSet], ) -> Result { // Normalize column inside Unnest - if let Expr::Unnest(Unnest { expr }) = expr { + if let Expr::Unnest(Unnest { expr }, _) = expr { let e = normalize_col_with_schemas_and_ambiguity_check( expr.as_ref().clone(), schemas, using_columns, )?; - return Ok(Expr::Unnest(Unnest { expr: Box::new(e) })); + return Ok(Expr::unnest(Unnest { expr: Box::new(e) })); } expr.transform(|expr| { Ok({ - if let Expr::Column(c) = expr { + if let Expr::Column(c, _) = expr { let col = c.normalize_with_schemas_and_ambiguity_check(schemas, using_columns)?; - Transformed::yes(Expr::Column(col)) + Transformed::yes(Expr::column(col)) } else { Transformed::no(expr) } @@ -135,9 +135,9 @@ pub fn normalize_sorts( pub fn replace_col(expr: Expr, replace_map: &HashMap<&Column, &Column>) -> Result { expr.transform(|expr| { Ok({ - if let Expr::Column(c) = &expr { + if let Expr::Column(c, _) = &expr { match replace_map.get(c) { - Some(new_c) => Transformed::yes(Expr::Column((*new_c).to_owned())), + Some(new_c) => Transformed::yes(Expr::column((*new_c).to_owned())), None => Transformed::no(expr), } } else { @@ -156,12 +156,12 @@ pub fn replace_col(expr: Expr, replace_map: &HashMap<&Column, &Column>) -> Resul pub fn unnormalize_col(expr: Expr) -> Expr { expr.transform(|expr| { Ok({ - if let Expr::Column(c) = expr { + if let Expr::Column(c, _) = expr { let col = Column { relation: None, name: c.name, }; - Transformed::yes(Expr::Column(col)) + Transformed::yes(Expr::column(col)) } else { Transformed::no(expr) } @@ -177,11 +177,11 @@ pub fn create_col_from_scalar_expr( subqry_alias: String, ) -> Result { match scalar_expr { - Expr::Alias(Alias { name, .. }) => Ok(Column::new( + Expr::Alias(Alias { name, .. }, _) => Ok(Column::new( Some::(subqry_alias.into()), name, )), - Expr::Column(Column { relation: _, name }) => Ok(Column::new( + Expr::Column(Column { relation: _, name }, _) => Ok(Column::new( Some::(subqry_alias.into()), name, )), @@ -206,8 +206,8 @@ pub fn unnormalize_cols(exprs: impl IntoIterator) -> Vec { pub fn strip_outer_reference(expr: Expr) -> Expr { expr.transform(|expr| { Ok({ - if let Expr::OuterReferenceColumn(_, col) = expr { - Transformed::yes(Expr::Column(col)) + if let Expr::OuterReferenceColumn(_, col, _) = expr { + Transformed::yes(Expr::column(col)) } else { Transformed::no(expr) } @@ -225,10 +225,10 @@ pub fn coerce_plan_expr_for_schema( ) -> Result { match plan { // special case Projection to avoid adding multiple projections - LogicalPlan::Projection(Projection { expr, input, .. }) => { + LogicalPlan::Projection(Projection { expr, input, .. }, _) => { let new_exprs = coerce_exprs_for_schema(expr, input.schema(), schema)?; let projection = Projection::try_new(new_exprs, input)?; - Ok(LogicalPlan::Projection(projection)) + Ok(LogicalPlan::projection(projection)) } _ => { let exprs: Vec = plan.schema().iter().map(Expr::from).collect(); @@ -236,7 +236,7 @@ pub fn coerce_plan_expr_for_schema( let add_project = new_exprs.iter().any(|expr| expr.try_as_col().is_none()); if add_project { let projection = Projection::try_new(new_exprs, Arc::new(plan))?; - Ok(LogicalPlan::Projection(projection)) + Ok(LogicalPlan::projection(projection)) } else { Ok(plan) } @@ -256,7 +256,7 @@ fn coerce_exprs_for_schema( let new_type = dst_schema.field(idx).data_type(); if new_type != &expr.get_type(src_schema)? { match expr { - Expr::Alias(Alias { expr, name, .. }) => { + Expr::Alias(Alias { expr, name, .. }, _) => { Ok(expr.cast_to(new_type, src_schema)?.alias(name)) } Expr::Wildcard { .. } => Ok(expr), @@ -273,7 +273,7 @@ fn coerce_exprs_for_schema( #[inline] pub fn unalias(expr: Expr) -> Expr { match expr { - Expr::Alias(Alias { expr, .. }) => unalias(*expr), + Expr::Alias(Alias { expr, .. }, _) => unalias(*expr), _ => expr, } } @@ -310,11 +310,11 @@ impl NamePreserver { // so there is no need to preserve expression names to prevent a schema change. use_alias: !matches!( plan, - LogicalPlan::Filter(_) - | LogicalPlan::Join(_) - | LogicalPlan::TableScan(_) - | LogicalPlan::Limit(_) - | LogicalPlan::Statement(_) + LogicalPlan::Filter(_, _) + | LogicalPlan::Join(_, _) + | LogicalPlan::TableScan(_, _) + | LogicalPlan::Limit(_, _) + | LogicalPlan::Statement(_, _) ), } } @@ -387,7 +387,7 @@ mod test { // rewrites all "foo" string literals to "bar" let transformer = |expr: Expr| -> Result> { match expr { - Expr::Literal(ScalarValue::Utf8(Some(utf8_val))) => { + Expr::Literal(ScalarValue::Utf8(Some(utf8_val)), _) => { let utf8_val = if utf8_val == "foo" { "bar".to_string() } else { @@ -514,7 +514,7 @@ mod test { // cast data types test_rewrite( col("a"), - Expr::Cast(Cast::new(Box::new(col("a")), DataType::Int32)), + Expr::cast(Cast::new(Box::new(col("a")), DataType::Int32)), ); // change literal type from i32 to i64 @@ -522,12 +522,12 @@ mod test { // test preserve qualifier test_rewrite( - Expr::Column(Column::new(Some("test"), "a")), - Expr::Column(Column::new_unqualified("test.a")), + Expr::column(Column::new(Some("test"), "a")), + Expr::column(Column::new_unqualified("test.a")), ); test_rewrite( - Expr::Column(Column::new_unqualified("test.a")), - Expr::Column(Column::new(Some("test"), "a")), + Expr::column(Column::new_unqualified("test.a")), + Expr::column(Column::new(Some("test"), "a")), ); } diff --git a/datafusion/expr/src/expr_rewriter/order_by.rs b/datafusion/expr/src/expr_rewriter/order_by.rs index f0d3d8fcd0c15..ca1cf1c4638ea 100644 --- a/datafusion/expr/src/expr_rewriter/order_by.rs +++ b/datafusion/expr/src/expr_rewriter/order_by.rs @@ -78,7 +78,7 @@ fn rewrite_in_terms_of_projection( // search for unnormalized names first such as "c1" (such as aliases) if let Some(found) = proj_exprs.iter().find(|a| (**a) == expr) { let (qualifier, field_name) = found.qualified_name(); - let col = Expr::Column(Column::new(qualifier, field_name)); + let col = Expr::column(Column::new(qualifier, field_name)); return Ok(Transformed::yes(col)); } @@ -98,7 +98,7 @@ fn rewrite_in_terms_of_projection( // for a column with the same "MIN(C2)", so translate there let name = normalized_expr.schema_name().to_string(); - let search_col = Expr::Column(Column { + let search_col = Expr::column(Column { relation: None, name, }); @@ -107,14 +107,16 @@ fn rewrite_in_terms_of_projection( if let Some(found) = proj_exprs.iter().find(|a| expr_match(&search_col, a)) { let found = found.clone(); return Ok(Transformed::yes(match normalized_expr { - Expr::Cast(Cast { expr: _, data_type }) => Expr::Cast(Cast { - expr: Box::new(found), - data_type, - }), - Expr::TryCast(TryCast { expr: _, data_type }) => Expr::TryCast(TryCast { + Expr::Cast(Cast { expr: _, data_type }, _) => Expr::cast(Cast { expr: Box::new(found), data_type, }), + Expr::TryCast(TryCast { expr: _, data_type }, _) => { + Expr::try_cast(TryCast { + expr: Box::new(found), + data_type, + }) + } _ => found, })); } @@ -128,7 +130,7 @@ fn rewrite_in_terms_of_projection( /// so avg(c) as average will match avgc fn expr_match(needle: &Expr, expr: &Expr) -> bool { // check inside aliases - if let Expr::Alias(Alias { expr, .. }) = &expr { + if let Expr::Alias(Alias { expr, .. }, _) = &expr { expr.as_ref() == needle } else { expr == needle diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index b1a461eca41db..461b23f8b783a 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -103,19 +103,19 @@ impl ExprSchemable for Expr { #[recursive] fn get_type(&self, schema: &dyn ExprSchema) -> Result { match self { - Expr::Alias(Alias { expr, name, .. }) => match &**expr { - Expr::Placeholder(Placeholder { data_type, .. }) => match &data_type { + Expr::Alias(Alias { expr, name, .. }, _) => match &**expr { + Expr::Placeholder(Placeholder { data_type, .. }, _) => match &data_type { None => schema.data_type(&Column::from_name(name)).cloned(), Some(dt) => Ok(dt.clone()), }, _ => expr.get_type(schema), }, - Expr::Negative(expr) => expr.get_type(schema), - Expr::Column(c) => Ok(schema.data_type(c)?.clone()), - Expr::OuterReferenceColumn(ty, _) => Ok(ty.clone()), - Expr::ScalarVariable(ty, _) => Ok(ty.clone()), - Expr::Literal(l) => Ok(l.data_type()), - Expr::Case(case) => { + Expr::Negative(expr, _) => expr.get_type(schema), + Expr::Column(c, _) => Ok(schema.data_type(c)?.clone()), + Expr::OuterReferenceColumn(ty, _, _) => Ok(ty.clone()), + Expr::ScalarVariable(ty, _, _) => Ok(ty.clone()), + Expr::Literal(l, _) => Ok(l.data_type()), + Expr::Case(case, _) => { for (_, then_expr) in &case.when_then_expr { let then_type = then_expr.get_type(schema)?; if !then_type.is_null() { @@ -126,9 +126,9 @@ impl ExprSchemable for Expr { .as_ref() .map_or(Ok(DataType::Null), |e| e.get_type(schema)) } - Expr::Cast(Cast { data_type, .. }) - | Expr::TryCast(TryCast { data_type, .. }) => Ok(data_type.clone()), - Expr::Unnest(Unnest { expr }) => { + Expr::Cast(Cast { data_type, .. }, _) + | Expr::TryCast(TryCast { data_type, .. }, _) => Ok(data_type.clone()), + Expr::Unnest(Unnest { expr }, _) => { let arg_data_type = expr.get_type(schema)?; // Unnest's output type is the inner type of the list match arg_data_type { @@ -146,7 +146,7 @@ impl ExprSchemable for Expr { } } } - Expr::ScalarFunction(ScalarFunction { func, args }) => { + Expr::ScalarFunction(ScalarFunction { func, args }, _) => { let arg_data_types = args .iter() .map(|e| e.get_type(schema)) @@ -170,10 +170,10 @@ impl ExprSchemable for Expr { // expressiveness of `TypeSignature`), then infer return type Ok(func.return_type_from_exprs(args, schema, &new_data_types)?) } - Expr::WindowFunction(window_function) => self + Expr::WindowFunction(window_function, _) => self .data_type_and_nullable_with_window_function(schema, window_function) .map(|(return_type, _)| return_type), - Expr::AggregateFunction(AggregateFunction { func, args, .. }) => { + Expr::AggregateFunction(AggregateFunction { func, args, .. }, _) => { let data_types = args .iter() .map(|e| e.get_type(schema)) @@ -192,29 +192,32 @@ impl ExprSchemable for Expr { })?; Ok(func.return_type(&new_types)?) } - Expr::Not(_) - | Expr::IsNull(_) - | Expr::Exists { .. } - | Expr::InSubquery(_) + Expr::Not(_, _) + | Expr::IsNull(_, _) + | Expr::Exists(_, _) + | Expr::InSubquery(_, _) | Expr::Between { .. } | Expr::InList { .. } - | Expr::IsNotNull(_) - | Expr::IsTrue(_) - | Expr::IsFalse(_) - | Expr::IsUnknown(_) - | Expr::IsNotTrue(_) - | Expr::IsNotFalse(_) - | Expr::IsNotUnknown(_) => Ok(DataType::Boolean), - Expr::ScalarSubquery(subquery) => { + | Expr::IsNotNull(_, _) + | Expr::IsTrue(_, _) + | Expr::IsFalse(_, _) + | Expr::IsUnknown(_, _) + | Expr::IsNotTrue(_, _) + | Expr::IsNotFalse(_, _) + | Expr::IsNotUnknown(_, _) => Ok(DataType::Boolean), + Expr::ScalarSubquery(subquery, _) => { Ok(subquery.subquery.schema().field(0).data_type().clone()) } - Expr::BinaryExpr(BinaryExpr { - ref left, - ref right, - ref op, - }) => get_result_type(&left.get_type(schema)?, op, &right.get_type(schema)?), + Expr::BinaryExpr( + BinaryExpr { + ref left, + ref right, + ref op, + }, + _, + ) => get_result_type(&left.get_type(schema)?, op, &right.get_type(schema)?), Expr::Like { .. } | Expr::SimilarTo { .. } => Ok(DataType::Boolean), - Expr::Placeholder(Placeholder { data_type, .. }) => { + Expr::Placeholder(Placeholder { data_type, .. }, _) => { data_type.clone().ok_or_else(|| { plan_datafusion_err!( "Placeholder type could not be resolved. Make sure that the \ @@ -224,7 +227,7 @@ impl ExprSchemable for Expr { }) } Expr::Wildcard { .. } => Ok(DataType::Null), - Expr::GroupingSet(_) => { + Expr::GroupingSet(_, _) => { // Grouping sets do not really have a type and do not appear in projections Ok(DataType::Null) } @@ -244,11 +247,11 @@ impl ExprSchemable for Expr { /// column that does not exist in the schema. fn nullable(&self, input_schema: &dyn ExprSchema) -> Result { match self { - Expr::Alias(Alias { expr, .. }) | Expr::Not(expr) | Expr::Negative(expr) => { - expr.nullable(input_schema) - } + Expr::Alias(Alias { expr, .. }, _) + | Expr::Not(expr, _) + | Expr::Negative(expr, _) => expr.nullable(input_schema), - Expr::InList(InList { expr, list, .. }) => { + Expr::InList(InList { expr, list, .. }, _) => { // Avoid inspecting too many expressions. const MAX_INSPECT_LIMIT: usize = 6; // Stop if a nullable expression is found or an error occurs. @@ -271,16 +274,19 @@ impl ExprSchemable for Expr { }) } - Expr::Between(Between { - expr, low, high, .. - }) => Ok(expr.nullable(input_schema)? + Expr::Between( + Between { + expr, low, high, .. + }, + _, + ) => Ok(expr.nullable(input_schema)? || low.nullable(input_schema)? || high.nullable(input_schema)?), - Expr::Column(c) => input_schema.nullable(c), - Expr::OuterReferenceColumn(_, _) => Ok(true), - Expr::Literal(value) => Ok(value.is_null()), - Expr::Case(case) => { + Expr::Column(c, _) => input_schema.nullable(c), + Expr::OuterReferenceColumn(_, _, _) => Ok(true), + Expr::Literal(value, _) => Ok(value.is_null()), + Expr::Case(case, _) => { // This expression is nullable if any of the input expressions are nullable let then_nullable = case .when_then_expr @@ -297,47 +303,50 @@ impl ExprSchemable for Expr { Ok(true) } } - Expr::Cast(Cast { expr, .. }) => expr.nullable(input_schema), - Expr::ScalarFunction(ScalarFunction { func, args }) => { + Expr::Cast(Cast { expr, .. }, _) => expr.nullable(input_schema), + Expr::ScalarFunction(ScalarFunction { func, args }, _) => { Ok(func.is_nullable(args, input_schema)) } - Expr::AggregateFunction(AggregateFunction { func, .. }) => { + Expr::AggregateFunction(AggregateFunction { func, .. }, _) => { Ok(func.is_nullable()) } - Expr::WindowFunction(window_function) => self + Expr::WindowFunction(window_function, _) => self .data_type_and_nullable_with_window_function( input_schema, window_function, ) .map(|(_, nullable)| nullable), - Expr::ScalarVariable(_, _) + Expr::ScalarVariable(_, _, _) | Expr::TryCast { .. } - | Expr::Unnest(_) - | Expr::Placeholder(_) => Ok(true), - Expr::IsNull(_) - | Expr::IsNotNull(_) - | Expr::IsTrue(_) - | Expr::IsFalse(_) - | Expr::IsUnknown(_) - | Expr::IsNotTrue(_) - | Expr::IsNotFalse(_) - | Expr::IsNotUnknown(_) + | Expr::Unnest(_, _) + | Expr::Placeholder(_, _) => Ok(true), + Expr::IsNull(_, _) + | Expr::IsNotNull(_, _) + | Expr::IsTrue(_, _) + | Expr::IsFalse(_, _) + | Expr::IsUnknown(_, _) + | Expr::IsNotTrue(_, _) + | Expr::IsNotFalse(_, _) + | Expr::IsNotUnknown(_, _) | Expr::Exists { .. } => Ok(false), - Expr::InSubquery(InSubquery { expr, .. }) => expr.nullable(input_schema), - Expr::ScalarSubquery(subquery) => { + Expr::InSubquery(InSubquery { expr, .. }, _) => expr.nullable(input_schema), + Expr::ScalarSubquery(subquery, _) => { Ok(subquery.subquery.schema().field(0).is_nullable()) } - Expr::BinaryExpr(BinaryExpr { - ref left, - ref right, - .. - }) => Ok(left.nullable(input_schema)? || right.nullable(input_schema)?), - Expr::Like(Like { expr, pattern, .. }) - | Expr::SimilarTo(Like { expr, pattern, .. }) => { + Expr::BinaryExpr( + BinaryExpr { + ref left, + ref right, + .. + }, + _, + ) => Ok(left.nullable(input_schema)? || right.nullable(input_schema)?), + Expr::Like(Like { expr, pattern, .. }, _) + | Expr::SimilarTo(Like { expr, pattern, .. }, _) => { Ok(expr.nullable(input_schema)? || pattern.nullable(input_schema)?) } Expr::Wildcard { .. } => Ok(false), - Expr::GroupingSet(_) => { + Expr::GroupingSet(_, _) => { // Grouping sets do not really have the concept of nullable and do not appear // in projections Ok(true) @@ -347,9 +356,9 @@ impl ExprSchemable for Expr { fn metadata(&self, schema: &dyn ExprSchema) -> Result> { match self { - Expr::Column(c) => Ok(schema.metadata(c)?.clone()), - Expr::Alias(Alias { expr, .. }) => expr.metadata(schema), - Expr::Cast(Cast { expr, .. }) => expr.metadata(schema), + Expr::Column(c, _) => Ok(schema.metadata(c)?.clone()), + Expr::Alias(Alias { expr, .. }, _) => expr.metadata(schema), + Expr::Cast(Cast { expr, .. }, _) => expr.metadata(schema), _ => Ok(HashMap::new()), } } @@ -369,8 +378,8 @@ impl ExprSchemable for Expr { schema: &dyn ExprSchema, ) -> Result<(DataType, bool)> { match self { - Expr::Alias(Alias { expr, name, .. }) => match &**expr { - Expr::Placeholder(Placeholder { data_type, .. }) => match &data_type { + Expr::Alias(Alias { expr, name, .. }, _) => match &**expr { + Expr::Placeholder(Placeholder { data_type, .. }, _) => match &data_type { None => schema .data_type_and_nullable(&Column::from_name(name)) .map(|(d, n)| (d.clone(), n)), @@ -378,36 +387,39 @@ impl ExprSchemable for Expr { }, _ => expr.data_type_and_nullable(schema), }, - Expr::Negative(expr) => expr.data_type_and_nullable(schema), - Expr::Column(c) => schema + Expr::Negative(expr, _) => expr.data_type_and_nullable(schema), + Expr::Column(c, _) => schema .data_type_and_nullable(c) .map(|(d, n)| (d.clone(), n)), - Expr::OuterReferenceColumn(ty, _) => Ok((ty.clone(), true)), - Expr::ScalarVariable(ty, _) => Ok((ty.clone(), true)), - Expr::Literal(l) => Ok((l.data_type(), l.is_null())), - Expr::IsNull(_) - | Expr::IsNotNull(_) - | Expr::IsTrue(_) - | Expr::IsFalse(_) - | Expr::IsUnknown(_) - | Expr::IsNotTrue(_) - | Expr::IsNotFalse(_) - | Expr::IsNotUnknown(_) + Expr::OuterReferenceColumn(ty, _, _) => Ok((ty.clone(), true)), + Expr::ScalarVariable(ty, _, _) => Ok((ty.clone(), true)), + Expr::Literal(l, _) => Ok((l.data_type(), l.is_null())), + Expr::IsNull(_, _) + | Expr::IsNotNull(_, _) + | Expr::IsTrue(_, _) + | Expr::IsFalse(_, _) + | Expr::IsUnknown(_, _) + | Expr::IsNotTrue(_, _) + | Expr::IsNotFalse(_, _) + | Expr::IsNotUnknown(_, _) | Expr::Exists { .. } => Ok((DataType::Boolean, false)), - Expr::ScalarSubquery(subquery) => Ok(( + Expr::ScalarSubquery(subquery, _) => Ok(( subquery.subquery.schema().field(0).data_type().clone(), subquery.subquery.schema().field(0).is_nullable(), )), - Expr::BinaryExpr(BinaryExpr { - ref left, - ref right, - ref op, - }) => { + Expr::BinaryExpr( + BinaryExpr { + ref left, + ref right, + ref op, + }, + _, + ) => { let left = left.data_type_and_nullable(schema)?; let right = right.data_type_and_nullable(schema)?; Ok((get_result_type(&left.0, op, &right.0)?, left.1 || right.1)) } - Expr::WindowFunction(window_function) => { + Expr::WindowFunction(window_function, _) => { self.data_type_and_nullable_with_window_function(schema, window_function) } _ => Ok((self.get_type(schema)?, self.nullable(schema)?)), @@ -448,10 +460,10 @@ impl ExprSchemable for Expr { if can_cast_types(&this_type, cast_to_type) { match self { - Expr::ScalarSubquery(subquery) => { - Ok(Expr::ScalarSubquery(cast_subquery(subquery, cast_to_type)?)) - } - _ => Ok(Expr::Cast(Cast::new(Box::new(self), cast_to_type.clone()))), + Expr::ScalarSubquery(subquery, _) => Ok(Expr::scalar_subquery( + cast_subquery(subquery, cast_to_type)?, + )), + _ => Ok(Expr::cast(Cast::new(Box::new(self), cast_to_type.clone()))), } } else { plan_err!("Cannot automatically convert {this_type:?} to {cast_to_type:?}") @@ -538,19 +550,19 @@ pub fn cast_subquery(subquery: Subquery, cast_to_type: &DataType) -> Result { + LogicalPlan::Projection(projection, _) => { let cast_expr = projection.expr[0] .clone() .cast_to(cast_to_type, projection.input.schema())?; - LogicalPlan::Projection(Projection::try_new( + LogicalPlan::projection(Projection::try_new( vec![cast_expr], Arc::clone(&projection.input), )?) } _ => { - let cast_expr = Expr::Column(Column::from(plan.schema().qualified_field(0))) + let cast_expr = Expr::column(Column::from(plan.schema().qualified_field(0))) .cast_to(cast_to_type, subquery.subquery.schema())?; - LogicalPlan::Projection(Projection::try_new( + LogicalPlan::projection(Projection::try_new( vec![cast_expr], subquery.subquery, )?) diff --git a/datafusion/expr/src/literal.rs b/datafusion/expr/src/literal.rs index 90ba5a9a693c7..e14a178a029f9 100644 --- a/datafusion/expr/src/literal.rs +++ b/datafusion/expr/src/literal.rs @@ -43,37 +43,37 @@ pub trait TimestampLiteral { impl Literal for &str { fn lit(&self) -> Expr { - Expr::Literal(ScalarValue::from(*self)) + Expr::literal(ScalarValue::from(*self)) } } impl Literal for String { fn lit(&self) -> Expr { - Expr::Literal(ScalarValue::from(self.as_ref())) + Expr::literal(ScalarValue::from(self.as_ref())) } } impl Literal for &String { fn lit(&self) -> Expr { - Expr::Literal(ScalarValue::from(self.as_ref())) + Expr::literal(ScalarValue::from(self.as_ref())) } } impl Literal for Vec { fn lit(&self) -> Expr { - Expr::Literal(ScalarValue::Binary(Some((*self).to_owned()))) + Expr::literal(ScalarValue::Binary(Some((*self).to_owned()))) } } impl Literal for &[u8] { fn lit(&self) -> Expr { - Expr::Literal(ScalarValue::Binary(Some((*self).to_owned()))) + Expr::literal(ScalarValue::Binary(Some((*self).to_owned()))) } } impl Literal for ScalarValue { fn lit(&self) -> Expr { - Expr::Literal(self.clone()) + Expr::literal(self.clone()) } } @@ -82,7 +82,7 @@ macro_rules! make_literal { #[doc = $DOC] impl Literal for $TYPE { fn lit(&self) -> Expr { - Expr::Literal(ScalarValue::$SCALAR(Some(self.clone()))) + Expr::literal(ScalarValue::$SCALAR(Some(self.clone()))) } } }; @@ -93,7 +93,7 @@ macro_rules! make_nonzero_literal { #[doc = $DOC] impl Literal for $TYPE { fn lit(&self) -> Expr { - Expr::Literal(ScalarValue::$SCALAR(Some(self.get()))) + Expr::literal(ScalarValue::$SCALAR(Some(self.get()))) } } }; @@ -104,7 +104,7 @@ macro_rules! make_timestamp_literal { #[doc = $DOC] impl TimestampLiteral for $TYPE { fn lit_timestamp_nano(&self) -> Expr { - Expr::Literal(ScalarValue::TimestampNanosecond( + Expr::literal(ScalarValue::TimestampNanosecond( Some((self.clone()).into()), None, )) diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index 90235e3f84c48..9ae7ab4911733 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -132,7 +132,7 @@ impl LogicalPlanBuilder { /// /// `produce_one_row` set to true means this empty node needs to produce a placeholder row. pub fn empty(produce_one_row: bool) -> Self { - Self::new(LogicalPlan::EmptyRelation(EmptyRelation { + Self::new(LogicalPlan::empty_relation(EmptyRelation { produce_one_row, schema: DFSchemaRef::new(DFSchema::empty()), })) @@ -164,7 +164,7 @@ impl LogicalPlanBuilder { // Ensure that the recursive term has the same field types as the static term let coerced_recursive_term = coerce_plan_expr_for_schema(recursive_term, self.plan.schema())?; - Ok(Self::from(LogicalPlan::RecursiveQuery(RecursiveQuery { + Ok(Self::from(LogicalPlan::recursive_query(RecursiveQuery { name, static_term: self.plan, recursive_term: Arc::new(coerced_recursive_term), @@ -307,8 +307,8 @@ impl LogicalPlanBuilder { // wrap cast if data type is not same as common type. for row in &mut values { for (j, field_type) in field_types.iter().enumerate() { - if let Expr::Literal(ScalarValue::Null) = row[j] { - row[j] = Expr::Literal(ScalarValue::try_from(field_type)?); + if let Expr::Literal(ScalarValue::Null, _) = row[j] { + row[j] = Expr::literal(ScalarValue::try_from(field_type)?); } else { row[j] = std::mem::take(&mut row[j]).cast_to(field_type, schema)?; } @@ -326,7 +326,7 @@ impl LogicalPlanBuilder { let dfschema = DFSchema::from_unqualified_fields(fields.into(), HashMap::new())?; let schema = DFSchemaRef::new(dfschema); - Ok(Self::new(LogicalPlan::Values(Values { schema, values }))) + Ok(Self::new(LogicalPlan::values(Values { schema, values }))) } /// Convert a table provider into a builder with a TableScan @@ -377,7 +377,7 @@ impl LogicalPlanBuilder { options: HashMap, partition_by: Vec, ) -> Result { - Ok(Self::new(LogicalPlan::Copy(CopyTo { + Ok(Self::new(LogicalPlan::copy(CopyTo { input: Arc::new(input), output_url, partition_by, @@ -395,7 +395,7 @@ impl LogicalPlanBuilder { ) -> Result { let table_schema = table_schema.clone().to_dfschema_ref()?; - Ok(Self::new(LogicalPlan::Dml(DmlStatement::new( + Ok(Self::new(LogicalPlan::dml(DmlStatement::new( table_name.into(), table_schema, WriteOp::Insert(insert_op), @@ -411,7 +411,7 @@ impl LogicalPlanBuilder { filters: Vec, ) -> Result { TableScan::try_new(table_name, table_source, projection, filters, None) - .map(LogicalPlan::TableScan) + .map(LogicalPlan::table_scan) .map(Self::new) } @@ -424,7 +424,7 @@ impl LogicalPlanBuilder { fetch: Option, ) -> Result { TableScan::try_new(table_name, table_source, projection, filters, fetch) - .map(LogicalPlan::TableScan) + .map(LogicalPlan::table_scan) .map(Self::new) } @@ -477,7 +477,7 @@ impl LogicalPlanBuilder { pub fn select(self, indices: impl IntoIterator) -> Result { let exprs: Vec<_> = indices .into_iter() - .map(|x| Expr::Column(Column::from(self.plan.schema().qualified_field(x)))) + .map(|x| Expr::column(Column::from(self.plan.schema().qualified_field(x)))) .collect(); self.project(exprs) } @@ -486,7 +486,7 @@ impl LogicalPlanBuilder { pub fn filter(self, expr: impl Into) -> Result { let expr = normalize_col(expr.into(), &self.plan)?; Filter::try_new(expr, self.plan) - .map(LogicalPlan::Filter) + .map(LogicalPlan::filter) .map(Self::new) } @@ -494,13 +494,13 @@ impl LogicalPlanBuilder { pub fn having(self, expr: impl Into) -> Result { let expr = normalize_col(expr.into(), &self.plan)?; Filter::try_new_with_having(expr, self.plan) - .map(LogicalPlan::Filter) + .map(LogicalPlan::filter) .map(Self::from) } /// Make a builder for a prepare logical plan from the builder's plan pub fn prepare(self, name: String, data_types: Vec) -> Result { - Ok(Self::new(LogicalPlan::Statement(Statement::Prepare( + Ok(Self::new(LogicalPlan::statement(Statement::Prepare( Prepare { name, data_types, @@ -529,7 +529,7 @@ impl LogicalPlanBuilder { /// /// Similar to `limit` but uses expressions for `skip` and `fetch` pub fn limit_by_expr(self, skip: Option, fetch: Option) -> Result { - Ok(Self::new(LogicalPlan::Limit(Limit { + Ok(Self::new(LogicalPlan::limit(Limit { skip: skip.map(Box::new), fetch: fetch.map(Box::new), input: self.plan, @@ -575,14 +575,17 @@ impl LogicalPlanBuilder { is_distinct: bool, ) -> Result { match curr_plan { - LogicalPlan::Projection(Projection { - input, - mut expr, - schema: _, - }) if missing_cols.iter().all(|c| input.schema().has_column(c)) => { + LogicalPlan::Projection( + Projection { + input, + mut expr, + schema: _, + }, + _, + ) if missing_cols.iter().all(|c| input.schema().has_column(c)) => { let mut missing_exprs = missing_cols .iter() - .map(|c| normalize_col(Expr::Column(c.clone()), &input)) + .map(|c| normalize_col(Expr::column(c.clone()), &input)) .collect::>>()?; // Do not let duplicate columns to be added, some of the @@ -597,7 +600,7 @@ impl LogicalPlanBuilder { } _ => { let is_distinct = - is_distinct || matches!(curr_plan, LogicalPlan::Distinct(_)); + is_distinct || matches!(curr_plan, LogicalPlan::Distinct(_, _)); let new_inputs = curr_plan .inputs() .into_iter() @@ -632,7 +635,7 @@ impl LogicalPlanBuilder { // As described in https://github.com/apache/datafusion/issues/5293 let all_aliases = missing_exprs.iter().all(|e| { projection_exprs.iter().any(|proj_expr| { - if let Expr::Alias(Alias { expr, .. }) = proj_expr { + if let Expr::Alias(Alias { expr, .. }, _) = proj_expr { e == expr.as_ref() } else { false @@ -696,7 +699,7 @@ impl LogicalPlanBuilder { })?; if missing_cols.is_empty() { - return Ok(Self::new(LogicalPlan::Sort(Sort { + return Ok(Self::new(LogicalPlan::sort(Sort { expr: normalize_sorts(sorts, &self.plan)?, input: self.plan, fetch, @@ -704,7 +707,7 @@ impl LogicalPlanBuilder { } // remove pushed down sort columns - let new_expr = schema.columns().into_iter().map(Expr::Column).collect(); + let new_expr = schema.columns().into_iter().map(Expr::column).collect(); let is_distinct = false; let plan = Self::add_missing_columns( @@ -712,14 +715,14 @@ impl LogicalPlanBuilder { &missing_cols, is_distinct, )?; - let sort_plan = LogicalPlan::Sort(Sort { + let sort_plan = LogicalPlan::sort(Sort { expr: normalize_sorts(sorts, &plan)?, input: Arc::new(plan), fetch, }); Projection::try_new(new_expr, Arc::new(sort_plan)) - .map(LogicalPlan::Projection) + .map(LogicalPlan::projection) .map(Self::new) } @@ -733,14 +736,14 @@ impl LogicalPlanBuilder { let left_plan: LogicalPlan = Arc::unwrap_or_clone(self.plan); let right_plan: LogicalPlan = plan; - Ok(Self::new(LogicalPlan::Distinct(Distinct::All(Arc::new( + Ok(Self::new(LogicalPlan::distinct(Distinct::All(Arc::new( union(left_plan, right_plan)?, ))))) } /// Apply deduplication: Only distinct (different) values are returned) pub fn distinct(self) -> Result { - Ok(Self::new(LogicalPlan::Distinct(Distinct::All(self.plan)))) + Ok(Self::new(LogicalPlan::distinct(Distinct::All(self.plan)))) } /// Project first values of the specified expression list according to the provided @@ -751,7 +754,7 @@ impl LogicalPlanBuilder { select_expr: Vec, sort_expr: Option>, ) -> Result { - Ok(Self::new(LogicalPlan::Distinct(Distinct::On( + Ok(Self::new(LogicalPlan::distinct(Distinct::On( DistinctOn::try_new(on_expr, select_expr, sort_expr, self.plan)?, )))) } @@ -962,12 +965,12 @@ impl LogicalPlanBuilder { let on = left_keys .into_iter() .zip(right_keys) - .map(|(l, r)| (Expr::Column(l), Expr::Column(r))) + .map(|(l, r)| (Expr::column(l), Expr::column(r))) .collect(); let join_schema = build_join_schema(self.plan.schema(), right.schema(), &join_type)?; - Ok(Self::new(LogicalPlan::Join(Join { + Ok(Self::new(LogicalPlan::join(Join { left: self.plan, right: Arc::new(right), on, @@ -1006,17 +1009,17 @@ impl LogicalPlanBuilder { && right.schema().has_column(r) && can_hash(self.plan.schema().field_from_column(l)?.data_type()) { - join_on.push((Expr::Column(l.clone()), Expr::Column(r.clone()))); + join_on.push((Expr::column(l.clone()), Expr::column(r.clone()))); } else if self.plan.schema().has_column(l) && right.schema().has_column(r) && can_hash(self.plan.schema().field_from_column(r)?.data_type()) { - join_on.push((Expr::Column(r.clone()), Expr::Column(l.clone()))); + join_on.push((Expr::column(r.clone()), Expr::column(l.clone()))); } else { let expr = binary_expr( - Expr::Column(l.clone()), + Expr::column(l.clone()), Operator::Eq, - Expr::Column(r.clone()), + Expr::column(r.clone()), ); match filters { None => filters = Some(expr), @@ -1031,7 +1034,7 @@ impl LogicalPlanBuilder { DataFusionError::Internal("filters should not be None here".to_string()) })?) } else { - Ok(Self::new(LogicalPlan::Join(Join { + Ok(Self::new(LogicalPlan::join(Join { left: self.plan, right: Arc::new(right), on: join_on, @@ -1048,7 +1051,7 @@ impl LogicalPlanBuilder { pub fn cross_join(self, right: LogicalPlan) -> Result { let join_schema = build_join_schema(self.plan.schema(), right.schema(), &JoinType::Inner)?; - Ok(Self::new(LogicalPlan::Join(Join { + Ok(Self::new(LogicalPlan::join(Join { left: self.plan, right: Arc::new(right), on: vec![], @@ -1062,7 +1065,7 @@ impl LogicalPlanBuilder { /// Repartition pub fn repartition(self, partitioning_scheme: Partitioning) -> Result { - Ok(Self::new(LogicalPlan::Repartition(Repartition { + Ok(Self::new(LogicalPlan::repartition(Repartition { input: self.plan, partitioning_scheme, }))) @@ -1075,7 +1078,7 @@ impl LogicalPlanBuilder { ) -> Result { let window_expr = normalize_cols(window_expr, &self.plan)?; validate_unique_names("Windows", &window_expr)?; - Ok(Self::new(LogicalPlan::Window(Window::try_new( + Ok(Self::new(LogicalPlan::window(Window::try_new( window_expr, self.plan, )?))) @@ -1095,7 +1098,7 @@ impl LogicalPlanBuilder { let group_expr = add_group_by_exprs_from_dependencies(group_expr, self.plan.schema())?; Aggregate::try_new(self.plan, group_expr, aggr_expr) - .map(LogicalPlan::Aggregate) + .map(LogicalPlan::aggregate) .map(Self::new) } @@ -1110,7 +1113,7 @@ impl LogicalPlanBuilder { let schema = schema.to_dfschema_ref()?; if analyze { - Ok(Self::new(LogicalPlan::Analyze(Analyze { + Ok(Self::new(LogicalPlan::analyze(Analyze { verbose, input: self.plan, schema, @@ -1119,7 +1122,7 @@ impl LogicalPlanBuilder { let stringified_plans = vec![self.plan.to_stringified(PlanType::InitialLogicalPlan)]; - Ok(Self::new(LogicalPlan::Explain(Explain { + Ok(Self::new(LogicalPlan::explain(Explain { verbose, plan: self.plan, stringified_plans, @@ -1266,7 +1269,7 @@ impl LogicalPlanBuilder { let join_schema = build_join_schema(self.plan.schema(), right.schema(), &join_type)?; - Ok(Self::new(LogicalPlan::Join(Join { + Ok(Self::new(LogicalPlan::join(Join { left: self.plan, right: Arc::new(right), on: join_key_pairs, @@ -1471,7 +1474,7 @@ pub fn add_group_by_exprs_from_dependencies( get_target_functional_dependencies(schema, &group_by_field_names) { for idx in target_indices { - let expr = Expr::Column(Column::from(schema.qualified_field(idx))); + let expr = Expr::column(Column::from(schema.qualified_field(idx))); let expr_name = expr.schema_name().to_string(); if !group_by_field_names.contains(&expr_name) { group_by_field_names.push(expr_name); @@ -1534,7 +1537,7 @@ pub fn union(left_plan: LogicalPlan, right_plan: LogicalPlan) -> Result, ) -> Result { - SubqueryAlias::try_new(Arc::new(plan), alias).map(LogicalPlan::SubqueryAlias) + SubqueryAlias::try_new(Arc::new(plan), alias).map(LogicalPlan::subquery_alias) } /// Create a LogicalPlanBuilder representing a scan of a table with the provided name and schema. @@ -1641,7 +1644,7 @@ pub fn wrap_projection_for_join_if_necessary( // // then a and cast(a as int) will use the same field name - `a` in projection schema. // https://github.com/apache/datafusion/issues/4478 - if matches!(key, Expr::Cast(_)) || matches!(key, Expr::TryCast(_)) { + if matches!(key, Expr::Cast(_, _)) || matches!(key, Expr::TryCast(_, _)) { let alias = format!("{key}"); key.clone().alias(alias) } else { @@ -1650,13 +1653,15 @@ pub fn wrap_projection_for_join_if_necessary( }) .collect::>(); - let need_project = join_keys.iter().any(|key| !matches!(key, Expr::Column(_))); + let need_project = join_keys + .iter() + .any(|key| !matches!(key, Expr::Column(_, _))); let plan = if need_project { // Include all columns from the input and extend them with the join keys let mut projection = input_schema .columns() .into_iter() - .map(Expr::Column) + .map(Expr::column) .collect::>(); let join_key_items = alias_join_keys .iter() @@ -1948,7 +1953,7 @@ pub fn unnest_with_options( let deps = input_schema.functional_dependencies().clone(); let schema = Arc::new(df_schema.with_functional_dependencies(deps)?); - Ok(LogicalPlan::Unnest(Unnest { + Ok(LogicalPlan::unnest(Unnest { input: Arc::new(input), exec_columns: columns_to_unnest, list_type_columns: list_columns, diff --git a/datafusion/expr/src/logical_plan/ddl.rs b/datafusion/expr/src/logical_plan/ddl.rs index 8c64a017988e9..bc9ebf6752bba 100644 --- a/datafusion/expr/src/logical_plan/ddl.rs +++ b/datafusion/expr/src/logical_plan/ddl.rs @@ -25,6 +25,7 @@ use std::{ }; use crate::expr::Sort; +use crate::logical_plan::tree_node::LogicalPlanStats; use arrow::datatypes::DataType; use datafusion_common::tree_node::{Transformed, TreeNodeContainer, TreeNodeRecursion}; use datafusion_common::{ @@ -188,6 +189,46 @@ impl DdlStatement { } Wrapper(self) } + + pub(crate) fn stats(&self) -> LogicalPlanStats { + match self { + DdlStatement::CreateExternalTable(CreateExternalTable { + order_exprs, + column_defaults, + .. + }) => order_exprs + .iter() + .flatten() + .map(|s| &s.expr) + .chain(column_defaults.values()) + .fold(LogicalPlanStats::empty(), |s, e| s.merge(e.stats())), + DdlStatement::CreateMemoryTable(CreateMemoryTable { + input, + column_defaults, + .. + }) => column_defaults + .iter() + .map(|(_, e)| e) + .fold(input.stats(), |s, e| s.merge(e.stats())), + DdlStatement::CreateView(CreateView { input, .. }) => input.stats(), + DdlStatement::CreateIndex(CreateIndex { columns, .. }) => columns + .iter() + .map(|s| &s.expr) + .fold(LogicalPlanStats::empty(), |s, e| s.merge(e.stats())), + DdlStatement::CreateFunction(CreateFunction { args, params, .. }) => args + .iter() + .flatten() + .flat_map(|a| a.default_expr.as_slice()) + .chain(params.function_body.as_slice()) + .fold(LogicalPlanStats::empty(), |s, e| s.merge(e.stats())), + DdlStatement::CreateCatalogSchema(_) + | DdlStatement::CreateCatalog(_) + | DdlStatement::DropTable(_) + | DdlStatement::DropView(_) + | DdlStatement::DropCatalogSchema(_) + | DdlStatement::DropFunction(_) => LogicalPlanStats::empty(), + } + } } /// Creates an external table. diff --git a/datafusion/expr/src/logical_plan/display.rs b/datafusion/expr/src/logical_plan/display.rs index b808defcb959c..2f3ba982feaa1 100644 --- a/datafusion/expr/src/logical_plan/display.rs +++ b/datafusion/expr/src/logical_plan/display.rs @@ -312,18 +312,18 @@ impl<'a, 'b> PgJsonVisitor<'a, 'b> { /// Converts a logical plan node to a json object. fn to_json_value(node: &LogicalPlan) -> serde_json::Value { match node { - LogicalPlan::EmptyRelation(_) => { + LogicalPlan::EmptyRelation(_, _) => { json!({ "Node Type": "EmptyRelation", }) } - LogicalPlan::RecursiveQuery(RecursiveQuery { is_distinct, .. }) => { + LogicalPlan::RecursiveQuery(RecursiveQuery { is_distinct, .. }, _) => { json!({ "Node Type": "RecursiveQuery", "Is Distinct": is_distinct, }) } - LogicalPlan::Values(Values { ref values, .. }) => { + LogicalPlan::Values(Values { ref values, .. }, _) => { let str_values = values .iter() // limit to only 5 values to avoid horrible display @@ -347,13 +347,16 @@ impl<'a, 'b> PgJsonVisitor<'a, 'b> { "Values": values_str }) } - LogicalPlan::TableScan(TableScan { - ref source, - ref table_name, - ref filters, - ref fetch, - .. - }) => { + LogicalPlan::TableScan( + TableScan { + ref source, + ref table_name, + ref filters, + ref fetch, + .. + }, + _, + ) => { let mut object = json!({ "Node Type": "TableScan", "Relation Name": table_name.table(), @@ -407,26 +410,29 @@ impl<'a, 'b> PgJsonVisitor<'a, 'b> { object } - LogicalPlan::Projection(Projection { ref expr, .. }) => { + LogicalPlan::Projection(Projection { ref expr, .. }, _) => { json!({ "Node Type": "Projection", "Expressions": expr.iter().map(|e| e.to_string()).collect::>() }) } - LogicalPlan::Dml(DmlStatement { table_name, op, .. }) => { + LogicalPlan::Dml(DmlStatement { table_name, op, .. }, _) => { json!({ "Node Type": "Projection", "Operation": op.name(), "Table Name": table_name.table() }) } - LogicalPlan::Copy(CopyTo { - input: _, - output_url, - file_type, - partition_by: _, - options, - }) => { + LogicalPlan::Copy( + CopyTo { + input: _, + output_url, + file_type, + partition_by: _, + options, + }, + _, + ) => { let op_str = options .iter() .map(|(k, v)| format!("{}={}", k, v)) @@ -439,41 +445,50 @@ impl<'a, 'b> PgJsonVisitor<'a, 'b> { "Options": op_str }) } - LogicalPlan::Ddl(ddl) => { + LogicalPlan::Ddl(ddl, _) => { json!({ "Node Type": "Ddl", "Operation": format!("{}", ddl.display()) }) } - LogicalPlan::Filter(Filter { - predicate: ref expr, - .. - }) => { + LogicalPlan::Filter( + Filter { + predicate: ref expr, + .. + }, + _, + ) => { json!({ "Node Type": "Filter", "Condition": format!("{}", expr) }) } - LogicalPlan::Window(Window { - ref window_expr, .. - }) => { + LogicalPlan::Window( + Window { + ref window_expr, .. + }, + _, + ) => { json!({ "Node Type": "WindowAggr", "Expressions": expr_vec_fmt!(window_expr) }) } - LogicalPlan::Aggregate(Aggregate { - ref group_expr, - ref aggr_expr, - .. - }) => { + LogicalPlan::Aggregate( + Aggregate { + ref group_expr, + ref aggr_expr, + .. + }, + _, + ) => { json!({ "Node Type": "Aggregate", "Group By": expr_vec_fmt!(group_expr), "Aggregates": expr_vec_fmt!(aggr_expr) }) } - LogicalPlan::Sort(Sort { expr, fetch, .. }) => { + LogicalPlan::Sort(Sort { expr, fetch, .. }, _) => { let mut object = json!({ "Node Type": "Sort", "Sort Key": expr_vec_fmt!(expr), @@ -485,13 +500,16 @@ impl<'a, 'b> PgJsonVisitor<'a, 'b> { object } - LogicalPlan::Join(Join { - on: ref keys, - filter, - join_constraint, - join_type, - .. - }) => { + LogicalPlan::Join( + Join { + on: ref keys, + filter, + join_constraint, + join_type, + .. + }, + _, + ) => { let join_expr: Vec = keys.iter().map(|(l, r)| format!("{l} = {r}")).collect(); let filter_expr = filter @@ -505,10 +523,13 @@ impl<'a, 'b> PgJsonVisitor<'a, 'b> { "Filter": format!("{}", filter_expr) }) } - LogicalPlan::Repartition(Repartition { - partitioning_scheme, - .. - }) => match partitioning_scheme { + LogicalPlan::Repartition( + Repartition { + partitioning_scheme, + .. + }, + _, + ) => match partitioning_scheme { Partitioning::RoundRobinBatch(n) => { json!({ "Node Type": "Repartition", @@ -537,11 +558,14 @@ impl<'a, 'b> PgJsonVisitor<'a, 'b> { }) } }, - LogicalPlan::Limit(Limit { - ref skip, - ref fetch, - .. - }) => { + LogicalPlan::Limit( + Limit { + ref skip, + ref fetch, + .. + }, + _, + ) => { let mut object = serde_json::json!( { "Node Type": "Limit", @@ -555,24 +579,24 @@ impl<'a, 'b> PgJsonVisitor<'a, 'b> { }; object } - LogicalPlan::Subquery(Subquery { .. }) => { + LogicalPlan::Subquery(Subquery { .. }, _) => { json!({ "Node Type": "Subquery" }) } - LogicalPlan::SubqueryAlias(SubqueryAlias { ref alias, .. }) => { + LogicalPlan::SubqueryAlias(SubqueryAlias { ref alias, .. }, _) => { json!({ "Node Type": "Subquery", "Alias": alias.table(), }) } - LogicalPlan::Statement(statement) => { + LogicalPlan::Statement(statement, _) => { json!({ "Node Type": "Statement", "Statement": format!("{}", statement.display()) }) } - LogicalPlan::Distinct(distinct) => match distinct { + LogicalPlan::Distinct(distinct, _) => match distinct { Distinct::All(_) => { json!({ "Node Type": "DistinctAll" @@ -607,28 +631,31 @@ impl<'a, 'b> PgJsonVisitor<'a, 'b> { "Node Type": "Analyze" }) } - LogicalPlan::Union(_) => { + LogicalPlan::Union(_, _) => { json!({ "Node Type": "Union" }) } - LogicalPlan::Extension(e) => { + LogicalPlan::Extension(e, _) => { json!({ "Node Type": e.node.name(), "Detail": format!("{:?}", e.node) }) } - LogicalPlan::DescribeTable(DescribeTable { .. }) => { + LogicalPlan::DescribeTable(DescribeTable { .. }, _) => { json!({ "Node Type": "DescribeTable" }) } - LogicalPlan::Unnest(Unnest { - input: plan, - list_type_columns: list_col_indices, - struct_type_columns: struct_col_indices, - .. - }) => { + LogicalPlan::Unnest( + Unnest { + input: plan, + list_type_columns: list_col_indices, + struct_type_columns: struct_col_indices, + .. + }, + _, + ) => { let input_columns = plan.schema().columns(); let list_type_columns = list_col_indices .iter() diff --git a/datafusion/expr/src/logical_plan/dml.rs b/datafusion/expr/src/logical_plan/dml.rs index 669bc8e8a7d34..021135a90952e 100644 --- a/datafusion/expr/src/logical_plan/dml.rs +++ b/datafusion/expr/src/logical_plan/dml.rs @@ -21,12 +21,12 @@ use std::fmt::{self, Debug, Display, Formatter}; use std::hash::{Hash, Hasher}; use std::sync::Arc; +use crate::logical_plan::tree_node::LogicalPlanStats; +use crate::LogicalPlan; use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::file_options::file_type::FileType; use datafusion_common::{DFSchemaRef, TableReference}; -use crate::LogicalPlan; - /// Operator that copies the contents of a database to file(s) #[derive(Clone)] pub struct CopyTo { @@ -89,6 +89,12 @@ impl Hash for CopyTo { } } +impl CopyTo { + pub(crate) fn stats(&self) -> LogicalPlanStats { + self.input.stats() + } +} + /// The operator that modifies the content of a database (adapted from /// substrait WriteRel) #[derive(Debug, Clone, PartialEq, Eq, Hash)] @@ -128,6 +134,10 @@ impl DmlStatement { pub fn name(&self) -> &str { self.op.name() } + + pub(crate) fn stats(&self) -> LogicalPlanStats { + self.input.stats() + } } // Manual implementation needed because of `table_schema` and `output_schema` fields. diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index e9ea2170cc7ab..912383e7da798 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -58,6 +58,7 @@ use indexmap::IndexSet; // backwards compatibility use crate::display::PgJsonVisitor; +use crate::logical_plan::tree_node::LogicalPlanStats; pub use datafusion_common::display::{PlanType, StringifiedPlan, ToStringifiedPlan}; pub use datafusion_common::{JoinConstraint, JoinType}; @@ -133,10 +134,10 @@ pub use datafusion_common::{JoinConstraint, JoinType}; /// assert_eq!(expressions.len(), 2); /// println!("Found expressions: {:?}", expressions); /// // found predicate in the Filter: employee.salary > 1000 -/// let salary = Expr::Column(Column::new(Some("employee"), "salary")); +/// let salary = Expr::column(Column::new(Some("employee"), "salary")); /// assert!(expressions.contains(&salary.gt(lit(1000)))); /// // found projection in the Projection: employee.name -/// let name = Expr::Column(Column::new(Some("employee"), "name")); +/// let name = Expr::column(Column::new(Some("employee"), "name")); /// assert!(expressions.contains(&name)); /// # Ok(()) /// # } @@ -170,10 +171,10 @@ pub use datafusion_common::{JoinConstraint, JoinType}; /// // use transform to rewrite the plan /// let transformed_result = plan.transform(|node| { /// // when we see the filter node -/// if let LogicalPlan::Filter(mut filter) = node { +/// if let LogicalPlan::Filter(mut filter, _) = node { /// // replace predicate with salary < 2000 -/// filter.predicate = Expr::Column(Column::new(Some("employee"), "salary")).lt(lit(2000)); -/// let new_plan = LogicalPlan::Filter(filter); +/// filter.predicate = Expr::column(Column::new(Some("employee"), "salary")).lt(lit(2000)); +/// let new_plan = LogicalPlan::filter(filter); /// return Ok(Transformed::yes(new_plan)); // communicate the node was changed /// } /// // return the node unchanged @@ -198,7 +199,7 @@ pub use datafusion_common::{JoinConstraint, JoinType}; pub enum LogicalPlan { /// Evaluates an arbitrary list of expressions (essentially a /// SELECT with an expression list) on its input. - Projection(Projection), + Projection(Projection, LogicalPlanStats), /// Filters rows from its input that do not match an /// expression (essentially a WHERE clause with a predicate /// expression). @@ -207,82 +208,269 @@ pub enum LogicalPlan { /// input; If the value of `` is true, the input row is /// passed to the output. If the value of `` is false /// (or null), the row is discarded. - Filter(Filter), + Filter(Filter, LogicalPlanStats), /// Windows input based on a set of window spec and window /// function (e.g. SUM or RANK). This is used to implement SQL /// window functions, and the `OVER` clause. - Window(Window), + Window(Window, LogicalPlanStats), /// Aggregates its input based on a set of grouping and aggregate /// expressions (e.g. SUM). This is used to implement SQL aggregates /// and `GROUP BY`. - Aggregate(Aggregate), + Aggregate(Aggregate, LogicalPlanStats), /// Sorts its input according to a list of sort expressions. This /// is used to implement SQL `ORDER BY` - Sort(Sort), + Sort(Sort, LogicalPlanStats), /// Join two logical plans on one or more join columns. /// This is used to implement SQL `JOIN` - Join(Join), + Join(Join, LogicalPlanStats), /// Repartitions the input based on a partitioning scheme. This is /// used to add parallelism and is sometimes referred to as an /// "exchange" operator in other systems - Repartition(Repartition), + Repartition(Repartition, LogicalPlanStats), /// Union multiple inputs with the same schema into a single /// output stream. This is used to implement SQL `UNION [ALL]` and /// `INTERSECT [ALL]`. - Union(Union), + Union(Union, LogicalPlanStats), /// Produces rows from a [`TableSource`], used to implement SQL /// `FROM` tables or views. - TableScan(TableScan), + TableScan(TableScan, LogicalPlanStats), /// Produces no rows: An empty relation with an empty schema that /// produces 0 or 1 row. This is used to implement SQL `SELECT` /// that has no values in the `FROM` clause. - EmptyRelation(EmptyRelation), + EmptyRelation(EmptyRelation, LogicalPlanStats), /// Produces the output of running another query. This is used to /// implement SQL subqueries - Subquery(Subquery), + Subquery(Subquery, LogicalPlanStats), /// Aliased relation provides, or changes, the name of a relation. - SubqueryAlias(SubqueryAlias), + SubqueryAlias(SubqueryAlias, LogicalPlanStats), /// Skip some number of rows, and then fetch some number of rows. - Limit(Limit), + Limit(Limit, LogicalPlanStats), /// A DataFusion [`Statement`] such as `SET VARIABLE` or `START TRANSACTION` - Statement(Statement), + Statement(Statement, LogicalPlanStats), /// Values expression. See /// [Postgres VALUES](https://www.postgresql.org/docs/current/queries-values.html) /// documentation for more details. This is used to implement SQL such as /// `VALUES (1, 2), (3, 4)` - Values(Values), + Values(Values, LogicalPlanStats), /// Produces a relation with string representations of /// various parts of the plan. This is used to implement SQL `EXPLAIN`. - Explain(Explain), + Explain(Explain, LogicalPlanStats), /// Runs the input, and prints annotated physical plan as a string /// with execution metric. This is used to implement SQL /// `EXPLAIN ANALYZE`. - Analyze(Analyze), + Analyze(Analyze, LogicalPlanStats), /// Extension operator defined outside of DataFusion. This is used /// to extend DataFusion with custom relational operations that - Extension(Extension), + Extension(Extension, LogicalPlanStats), /// Remove duplicate rows from the input. This is used to /// implement SQL `SELECT DISTINCT ...`. - Distinct(Distinct), + Distinct(Distinct, LogicalPlanStats), /// Data Manipulation Language (DML): Insert / Update / Delete - Dml(DmlStatement), + Dml(DmlStatement, LogicalPlanStats), /// Data Definition Language (DDL): CREATE / DROP TABLES / VIEWS / SCHEMAS - Ddl(DdlStatement), + Ddl(DdlStatement, LogicalPlanStats), /// `COPY TO` for writing plan results to files - Copy(CopyTo), + Copy(CopyTo, LogicalPlanStats), /// Describe the schema of the table. This is used to implement the /// SQL `DESCRIBE` command from MySQL. - DescribeTable(DescribeTable), + DescribeTable(DescribeTable, LogicalPlanStats), /// Unnest a column that contains a nested list type such as an /// ARRAY. This is used to implement SQL `UNNEST` - Unnest(Unnest), + Unnest(Unnest, LogicalPlanStats), /// A variadic query (e.g. "Recursive CTEs") - RecursiveQuery(RecursiveQuery), + RecursiveQuery(RecursiveQuery, LogicalPlanStats), } +// impl From for LogicalPlan { +// fn from(projection: Projection) -> Self { +// let stats = projection.stats(); +// LogicalPlan::Projection(projection, stats) +// } +// } +// +// impl From for LogicalPlan { +// fn from(projection: Filter) -> Self { +// let stats = projection.stats(); +// LogicalPlan::Filter(projection, stats) +// } +// } +// +// impl From for LogicalPlan { +// fn from(projection: Window) -> Self { +// let stats = projection.stats(); +// LogicalPlan::Window(projection, stats) +// } +// } +// +// impl From for LogicalPlan { +// fn from(projection: Aggregate) -> Self { +// let stats = projection.stats(); +// LogicalPlan::Aggregate(projection, stats) +// } +// } +// +// impl From for LogicalPlan { +// fn from(projection: Sort) -> Self { +// let stats = projection.stats(); +// LogicalPlan::Sort(projection, stats) +// } +// } +// +// impl From for LogicalPlan { +// fn from(projection: Join) -> Self { +// let stats = projection.stats(); +// LogicalPlan::Join(projection, stats) +// } +// } +// +// impl From for LogicalPlan { +// fn from(projection: Repartition) -> Self { +// let stats = projection.stats(); +// LogicalPlan::Repartition(projection, stats) +// } +// } +// +// impl From for LogicalPlan { +// fn from(projection: Union) -> Self { +// let stats = projection.stats(); +// LogicalPlan::Union(projection, stats) +// } +// } +// +// impl From for LogicalPlan { +// fn from(projection: TableScan) -> Self { +// let stats = projection.stats(); +// LogicalPlan::TableScan(projection, stats) +// } +// } +// +// impl From for LogicalPlan { +// fn from(projection: EmptyRelation) -> Self { +// LogicalPlan::EmptyRelation(projection) +// } +// } +// +// impl From for LogicalPlan { +// fn from(projection: Subquery) -> Self { +// let stats = projection.stats(); +// LogicalPlan::Subquery(projection, stats) +// } +// } +// +// impl From for LogicalPlan { +// fn from(projection: SubqueryAlias) -> Self { +// let stats = projection.stats(); +// LogicalPlan::SubqueryAlias(projection, stats) +// } +// } +// +// impl From for LogicalPlan { +// fn from(projection: Limit) -> Self { +// let stats = projection.stats(); +// LogicalPlan::Limit(projection, stats) +// } +// } +// +// impl From for LogicalPlan { +// fn from(projection: Statement) -> Self { +// LogicalPlan::Statement(projection) +// } +// } +// +// impl From for LogicalPlan { +// fn from(projection: Values) -> Self { +// let stats = projection.stats(); +// LogicalPlan::Values(projection, stats) +// } +// } +// +// impl From for LogicalPlan { +// fn from(projection: Explain) -> Self { +// let stats = projection.stats(); +// LogicalPlan::Explain(projection, stats) +// } +// } +// +// impl From for LogicalPlan { +// fn from(projection: Analyze) -> Self { +// let stats = projection.stats(); +// LogicalPlan::Analyze(projection, stats) +// } +// } +// +// impl From for LogicalPlan { +// fn from(projection: Extension) -> Self { +// let stats = projection.stats(); +// LogicalPlan::Extension(projection, stats) +// } +// } +// +// impl From for LogicalPlan { +// fn from(projection: Distinct) -> Self { +// let stats = projection.stats(); +// LogicalPlan::Distinct(projection, stats) +// } +// } +// +// impl From for LogicalPlan { +// fn from(projection: Prepare) -> Self { +// let stats = projection.stats(); +// LogicalPlan::Prepare(projection, stats) +// } +// } +// +// impl From for LogicalPlan { +// fn from(projection: Execute) -> Self { +// let stats = projection.stats(); +// LogicalPlan::Execute(projection, stats) +// } +// } +// +// +// impl From for LogicalPlan { +// fn from(projection: DmlStatement) -> Self { +// let stats = projection.stats(); +// LogicalPlan::Dml(projection, stats) +// } +// } +// +// impl From for LogicalPlan { +// fn from(projection: DdlStatement) -> Self { +// let stats = projection.stats(); +// LogicalPlan::Ddl(projection, stats) +// } +// } +// +// impl From for LogicalPlan { +// fn from(projection: CopyTo) -> Self { +// let stats = projection.stats(); +// LogicalPlan::Copy(projection, stats) +// } +// } +// +// impl From for LogicalPlan { +// fn from(projection: DescribeTable) -> Self { +// LogicalPlan::DescribeTable(projection) +// } +// } +// +// impl From for LogicalPlan { +// fn from(projection: Unnest) -> Self { +// let stats = projection.stats(); +// LogicalPlan::Unnest(projection, stats) +// } +// } +// +// impl From for LogicalPlan { +// fn from(projection: RecursiveQuery) -> Self { +// let stats = projection.stats(); +// LogicalPlan::RecursiveQuery(projection, stats) +// } +// } + impl Default for LogicalPlan { fn default() -> Self { - LogicalPlan::EmptyRelation(EmptyRelation { + LogicalPlan::empty_relation(EmptyRelation { produce_one_row: false, schema: Arc::new(DFSchema::empty()), }) @@ -309,36 +497,39 @@ impl LogicalPlan { /// Get a reference to the logical plan's schema pub fn schema(&self) -> &DFSchemaRef { match self { - LogicalPlan::EmptyRelation(EmptyRelation { schema, .. }) => schema, - LogicalPlan::Values(Values { schema, .. }) => schema, - LogicalPlan::TableScan(TableScan { - projected_schema, .. - }) => projected_schema, - LogicalPlan::Projection(Projection { schema, .. }) => schema, - LogicalPlan::Filter(Filter { input, .. }) => input.schema(), - LogicalPlan::Distinct(Distinct::All(input)) => input.schema(), - LogicalPlan::Distinct(Distinct::On(DistinctOn { schema, .. })) => schema, - LogicalPlan::Window(Window { schema, .. }) => schema, - LogicalPlan::Aggregate(Aggregate { schema, .. }) => schema, - LogicalPlan::Sort(Sort { input, .. }) => input.schema(), - LogicalPlan::Join(Join { schema, .. }) => schema, - LogicalPlan::Repartition(Repartition { input, .. }) => input.schema(), - LogicalPlan::Limit(Limit { input, .. }) => input.schema(), - LogicalPlan::Statement(statement) => statement.schema(), - LogicalPlan::Subquery(Subquery { subquery, .. }) => subquery.schema(), - LogicalPlan::SubqueryAlias(SubqueryAlias { schema, .. }) => schema, - LogicalPlan::Explain(explain) => &explain.schema, - LogicalPlan::Analyze(analyze) => &analyze.schema, - LogicalPlan::Extension(extension) => extension.node.schema(), - LogicalPlan::Union(Union { schema, .. }) => schema, - LogicalPlan::DescribeTable(DescribeTable { output_schema, .. }) => { + LogicalPlan::EmptyRelation(EmptyRelation { schema, .. }, _) => schema, + LogicalPlan::Values(Values { schema, .. }, _) => schema, + LogicalPlan::TableScan( + TableScan { + projected_schema, .. + }, + _, + ) => projected_schema, + LogicalPlan::Projection(Projection { schema, .. }, _) => schema, + LogicalPlan::Filter(Filter { input, .. }, _) => input.schema(), + LogicalPlan::Distinct(Distinct::All(input), _) => input.schema(), + LogicalPlan::Distinct(Distinct::On(DistinctOn { schema, .. }), _) => schema, + LogicalPlan::Window(Window { schema, .. }, _) => schema, + LogicalPlan::Aggregate(Aggregate { schema, .. }, _) => schema, + LogicalPlan::Sort(Sort { input, .. }, _) => input.schema(), + LogicalPlan::Join(Join { schema, .. }, _) => schema, + LogicalPlan::Repartition(Repartition { input, .. }, _) => input.schema(), + LogicalPlan::Limit(Limit { input, .. }, _) => input.schema(), + LogicalPlan::Statement(statement, _) => statement.schema(), + LogicalPlan::Subquery(Subquery { subquery, .. }, _) => subquery.schema(), + LogicalPlan::SubqueryAlias(SubqueryAlias { schema, .. }, _) => schema, + LogicalPlan::Explain(explain, _) => &explain.schema, + LogicalPlan::Analyze(analyze, _) => &analyze.schema, + LogicalPlan::Extension(extension, _) => extension.node.schema(), + LogicalPlan::Union(Union { schema, .. }, _) => schema, + LogicalPlan::DescribeTable(DescribeTable { output_schema, .. }, _) => { output_schema } - LogicalPlan::Dml(DmlStatement { output_schema, .. }) => output_schema, - LogicalPlan::Copy(CopyTo { input, .. }) => input.schema(), - LogicalPlan::Ddl(ddl) => ddl.schema(), - LogicalPlan::Unnest(Unnest { schema, .. }) => schema, - LogicalPlan::RecursiveQuery(RecursiveQuery { static_term, .. }) => { + LogicalPlan::Dml(DmlStatement { output_schema, .. }, _) => output_schema, + LogicalPlan::Copy(CopyTo { input, .. }, _) => input.schema(), + LogicalPlan::Ddl(ddl, _) => ddl.schema(), + LogicalPlan::Unnest(Unnest { schema, .. }, _) => schema, + LogicalPlan::RecursiveQuery(RecursiveQuery { static_term, .. }, _) => { // we take the schema of the static term as the schema of the entire recursive query static_term.schema() } @@ -349,11 +540,11 @@ impl LogicalPlan { /// of the plan. pub fn fallback_normalize_schemas(&self) -> Vec<&DFSchema> { match self { - LogicalPlan::Window(_) - | LogicalPlan::Projection(_) - | LogicalPlan::Aggregate(_) - | LogicalPlan::Unnest(_) - | LogicalPlan::Join(_) => self + LogicalPlan::Window { .. } + | LogicalPlan::Projection { .. } + | LogicalPlan::Aggregate { .. } + | LogicalPlan::Unnest { .. } + | LogicalPlan::Join { .. } => self .inputs() .iter() .map(|input| input.schema().as_ref()) @@ -436,40 +627,44 @@ impl LogicalPlan { /// Note does not include inputs to inputs, or subqueries. pub fn inputs(&self) -> Vec<&LogicalPlan> { match self { - LogicalPlan::Projection(Projection { input, .. }) => vec![input], - LogicalPlan::Filter(Filter { input, .. }) => vec![input], - LogicalPlan::Repartition(Repartition { input, .. }) => vec![input], - LogicalPlan::Window(Window { input, .. }) => vec![input], - LogicalPlan::Aggregate(Aggregate { input, .. }) => vec![input], - LogicalPlan::Sort(Sort { input, .. }) => vec![input], - LogicalPlan::Join(Join { left, right, .. }) => vec![left, right], - LogicalPlan::Limit(Limit { input, .. }) => vec![input], - LogicalPlan::Subquery(Subquery { subquery, .. }) => vec![subquery], - LogicalPlan::SubqueryAlias(SubqueryAlias { input, .. }) => vec![input], - LogicalPlan::Extension(extension) => extension.node.inputs(), - LogicalPlan::Union(Union { inputs, .. }) => { + LogicalPlan::Projection(Projection { input, .. }, _) => vec![input], + LogicalPlan::Filter(Filter { input, .. }, _) => vec![input], + LogicalPlan::Repartition(Repartition { input, .. }, _) => vec![input], + LogicalPlan::Window(Window { input, .. }, _) => vec![input], + LogicalPlan::Aggregate(Aggregate { input, .. }, _) => vec![input], + LogicalPlan::Sort(Sort { input, .. }, _) => vec![input], + LogicalPlan::Join(Join { left, right, .. }, _) => vec![left, right], + LogicalPlan::Limit(Limit { input, .. }, _) => vec![input], + LogicalPlan::Subquery(Subquery { subquery, .. }, _) => vec![subquery], + LogicalPlan::SubqueryAlias(SubqueryAlias { input, .. }, _) => vec![input], + LogicalPlan::Extension(extension, _) => extension.node.inputs(), + LogicalPlan::Union(Union { inputs, .. }, _) => { inputs.iter().map(|arc| arc.as_ref()).collect() } LogicalPlan::Distinct( Distinct::All(input) | Distinct::On(DistinctOn { input, .. }), + _, ) => vec![input], - LogicalPlan::Explain(explain) => vec![&explain.plan], - LogicalPlan::Analyze(analyze) => vec![&analyze.input], - LogicalPlan::Dml(write) => vec![&write.input], - LogicalPlan::Copy(copy) => vec![©.input], - LogicalPlan::Ddl(ddl) => ddl.inputs(), - LogicalPlan::Unnest(Unnest { input, .. }) => vec![input], - LogicalPlan::RecursiveQuery(RecursiveQuery { - static_term, - recursive_term, - .. - }) => vec![static_term, recursive_term], - LogicalPlan::Statement(stmt) => stmt.inputs(), + LogicalPlan::Explain(explain, _) => vec![&explain.plan], + LogicalPlan::Analyze(analyze, _) => vec![&analyze.input], + LogicalPlan::Dml(write, _) => vec![&write.input], + LogicalPlan::Copy(copy, _) => vec![©.input], + LogicalPlan::Ddl(ddl, _) => ddl.inputs(), + LogicalPlan::Unnest(Unnest { input, .. }, _) => vec![input], + LogicalPlan::RecursiveQuery( + RecursiveQuery { + static_term, + recursive_term, + .. + }, + _, + ) => vec![static_term, recursive_term], + LogicalPlan::Statement(stmt, _) => stmt.inputs(), // plans without inputs LogicalPlan::TableScan { .. } | LogicalPlan::EmptyRelation { .. } | LogicalPlan::Values { .. } - | LogicalPlan::DescribeTable(_) => vec![], + | LogicalPlan::DescribeTable(_, _) => vec![], } } @@ -478,11 +673,14 @@ impl LogicalPlan { let mut using_columns: Vec> = vec![]; self.apply_with_subqueries(|plan| { - if let LogicalPlan::Join(Join { - join_constraint: JoinConstraint::Using, - on, - .. - }) = plan + if let LogicalPlan::Join( + Join { + join_constraint: JoinConstraint::Using, + on, + .. + }, + _, + ) = plan { // The join keys in using-join must be columns. let columns = @@ -512,31 +710,34 @@ impl LogicalPlan { /// returns the first output expression of this `LogicalPlan` node. pub fn head_output_expr(&self) -> Result> { match self { - LogicalPlan::Projection(projection) => { + LogicalPlan::Projection(projection, _) => { Ok(Some(projection.expr.as_slice()[0].clone())) } - LogicalPlan::Aggregate(agg) => { + LogicalPlan::Aggregate(agg, _) => { if agg.group_expr.is_empty() { Ok(Some(agg.aggr_expr.as_slice()[0].clone())) } else { Ok(Some(agg.group_expr.as_slice()[0].clone())) } } - LogicalPlan::Distinct(Distinct::On(DistinctOn { select_expr, .. })) => { + LogicalPlan::Distinct(Distinct::On(DistinctOn { select_expr, .. }), _) => { Ok(Some(select_expr[0].clone())) } - LogicalPlan::Filter(Filter { input, .. }) - | LogicalPlan::Distinct(Distinct::All(input)) - | LogicalPlan::Sort(Sort { input, .. }) - | LogicalPlan::Limit(Limit { input, .. }) - | LogicalPlan::Repartition(Repartition { input, .. }) - | LogicalPlan::Window(Window { input, .. }) => input.head_output_expr(), - LogicalPlan::Join(Join { - left, - right, - join_type, - .. - }) => match join_type { + LogicalPlan::Filter(Filter { input, .. }, _) + | LogicalPlan::Distinct(Distinct::All(input), _) + | LogicalPlan::Sort(Sort { input, .. }, _) + | LogicalPlan::Limit(Limit { input, .. }, _) + | LogicalPlan::Repartition(Repartition { input, .. }, _) + | LogicalPlan::Window(Window { input, .. }, _) => input.head_output_expr(), + LogicalPlan::Join( + Join { + left, + right, + join_type, + .. + }, + _, + ) => match join_type { JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full => { if left.schema().fields().is_empty() { right.head_output_expr() @@ -549,38 +750,38 @@ impl LogicalPlan { } JoinType::RightSemi | JoinType::RightAnti => right.head_output_expr(), }, - LogicalPlan::RecursiveQuery(RecursiveQuery { static_term, .. }) => { + LogicalPlan::RecursiveQuery(RecursiveQuery { static_term, .. }, _) => { static_term.head_output_expr() } - LogicalPlan::Union(union) => Ok(Some(Expr::Column(Column::from( + LogicalPlan::Union(union, _) => Ok(Some(Expr::column(Column::from( union.schema.qualified_field(0), )))), - LogicalPlan::TableScan(table) => Ok(Some(Expr::Column(Column::from( + LogicalPlan::TableScan(table, _) => Ok(Some(Expr::column(Column::from( table.projected_schema.qualified_field(0), )))), - LogicalPlan::SubqueryAlias(subquery_alias) => { + LogicalPlan::SubqueryAlias(subquery_alias, _) => { let expr_opt = subquery_alias.input.head_output_expr()?; expr_opt .map(|expr| { - Ok(Expr::Column(create_col_from_scalar_expr( + Ok(Expr::column(create_col_from_scalar_expr( &expr, subquery_alias.alias.to_string(), )?)) }) .map_or(Ok(None), |v| v.map(Some)) } - LogicalPlan::Subquery(_) => Ok(None), - LogicalPlan::EmptyRelation(_) - | LogicalPlan::Statement(_) - | LogicalPlan::Values(_) - | LogicalPlan::Explain(_) - | LogicalPlan::Analyze(_) - | LogicalPlan::Extension(_) - | LogicalPlan::Dml(_) - | LogicalPlan::Copy(_) - | LogicalPlan::Ddl(_) - | LogicalPlan::DescribeTable(_) - | LogicalPlan::Unnest(_) => Ok(None), + LogicalPlan::Subquery { .. } => Ok(None), + LogicalPlan::EmptyRelation { .. } + | LogicalPlan::Statement { .. } + | LogicalPlan::Values { .. } + | LogicalPlan::Explain { .. } + | LogicalPlan::Analyze { .. } + | LogicalPlan::Extension { .. } + | LogicalPlan::Dml { .. } + | LogicalPlan::Copy { .. } + | LogicalPlan::Ddl { .. } + | LogicalPlan::DescribeTable { .. } + | LogicalPlan::Unnest { .. } => Ok(None), } } @@ -610,47 +811,62 @@ impl LogicalPlan { match self { // Since expr may be different than the previous expr, schema of the projection // may change. We need to use try_new method instead of try_new_with_schema method. - LogicalPlan::Projection(Projection { - expr, - input, - schema: _, - }) => Projection::try_new(expr, input).map(LogicalPlan::Projection), - LogicalPlan::Dml(_) => Ok(self), - LogicalPlan::Copy(_) => Ok(self), - LogicalPlan::Values(Values { schema, values }) => { + LogicalPlan::Projection( + Projection { + expr, + input, + schema: _, + }, + _, + ) => Projection::try_new(expr, input).map(LogicalPlan::projection), + LogicalPlan::Dml { .. } => Ok(self), + LogicalPlan::Copy { .. } => Ok(self), + LogicalPlan::Values(Values { schema, values }, _) => { // todo it isn't clear why the schema is not recomputed here - Ok(LogicalPlan::Values(Values { schema, values })) + Ok(LogicalPlan::values(Values { schema, values })) } - LogicalPlan::Filter(Filter { - predicate, - input, - having, - }) => Filter::try_new_internal(predicate, input, having) - .map(LogicalPlan::Filter), - LogicalPlan::Repartition(_) => Ok(self), - LogicalPlan::Window(Window { - input, - window_expr, - schema: _, - }) => Window::try_new(window_expr, input).map(LogicalPlan::Window), - LogicalPlan::Aggregate(Aggregate { - input, - group_expr, - aggr_expr, - schema: _, - }) => Aggregate::try_new(input, group_expr, aggr_expr) - .map(LogicalPlan::Aggregate), - LogicalPlan::Sort(_) => Ok(self), - LogicalPlan::Join(Join { - left, - right, - filter, - join_type, - join_constraint, - on, - schema: _, - null_equals_null, - }) => { + LogicalPlan::Filter( + Filter { + predicate, + input, + having, + }, + _, + ) => Filter::try_new_internal(predicate, input, having) + .map(LogicalPlan::filter), + LogicalPlan::Repartition { .. } => Ok(self), + LogicalPlan::Window( + Window { + input, + window_expr, + schema: _, + }, + _, + ) => Window::try_new(window_expr, input).map(LogicalPlan::window), + LogicalPlan::Aggregate( + Aggregate { + input, + group_expr, + aggr_expr, + schema: _, + }, + _, + ) => Aggregate::try_new(input, group_expr, aggr_expr) + .map(LogicalPlan::aggregate), + LogicalPlan::Sort { .. } => Ok(self), + LogicalPlan::Join( + Join { + left, + right, + filter, + join_type, + join_constraint, + on, + schema: _, + null_equals_null, + }, + _, + ) => { let schema = build_join_schema(left.schema(), right.schema(), &join_type)?; @@ -662,7 +878,7 @@ impl LogicalPlan { }) .collect(); - Ok(LogicalPlan::Join(Join { + Ok(LogicalPlan::join(Join { left, right, join_type, @@ -673,24 +889,27 @@ impl LogicalPlan { null_equals_null, })) } - LogicalPlan::Subquery(_) => Ok(self), - LogicalPlan::SubqueryAlias(SubqueryAlias { - input, - alias, - schema: _, - }) => SubqueryAlias::try_new(input, alias).map(LogicalPlan::SubqueryAlias), - LogicalPlan::Limit(_) => Ok(self), - LogicalPlan::Ddl(_) => Ok(self), - LogicalPlan::Extension(Extension { node }) => { + LogicalPlan::Subquery { .. } => Ok(self), + LogicalPlan::SubqueryAlias( + SubqueryAlias { + input, + alias, + schema: _, + }, + _, + ) => SubqueryAlias::try_new(input, alias).map(LogicalPlan::subquery_alias), + LogicalPlan::Limit { .. } => Ok(self), + LogicalPlan::Ddl { .. } => Ok(self), + LogicalPlan::Extension(Extension { node }, _) => { // todo make an API that does not require cloning // This requires a copy of the extension nodes expressions and inputs let expr = node.expressions(); let inputs: Vec<_> = node.inputs().into_iter().cloned().collect(); - Ok(LogicalPlan::Extension(Extension { + Ok(LogicalPlan::extension(Extension { node: node.with_exprs_and_inputs(expr, inputs)?, })) } - LogicalPlan::Union(Union { inputs, schema }) => { + LogicalPlan::Union(Union { inputs, schema }, _) => { let input_schema = inputs[0].schema(); // If inputs are not pruned do not change schema // TODO this seems wrong (shouldn't we always use the schema of the input?) @@ -699,9 +918,9 @@ impl LogicalPlan { } else { Arc::clone(input_schema) }; - Ok(LogicalPlan::Union(Union { inputs, schema })) + Ok(LogicalPlan::union(Union { inputs, schema })) } - LogicalPlan::Distinct(distinct) => { + LogicalPlan::Distinct(distinct, _) => { let distinct = match distinct { Distinct::All(input) => Distinct::All(input), Distinct::On(DistinctOn { @@ -717,21 +936,24 @@ impl LogicalPlan { input, )?), }; - Ok(LogicalPlan::Distinct(distinct)) + Ok(LogicalPlan::distinct(distinct)) } - LogicalPlan::RecursiveQuery(_) => Ok(self), - LogicalPlan::Analyze(_) => Ok(self), - LogicalPlan::Explain(_) => Ok(self), - LogicalPlan::TableScan(_) => Ok(self), - LogicalPlan::EmptyRelation(_) => Ok(self), - LogicalPlan::Statement(_) => Ok(self), - LogicalPlan::DescribeTable(_) => Ok(self), - LogicalPlan::Unnest(Unnest { - input, - exec_columns, - options, - .. - }) => { + LogicalPlan::RecursiveQuery { .. } => Ok(self), + LogicalPlan::Analyze { .. } => Ok(self), + LogicalPlan::Explain { .. } => Ok(self), + LogicalPlan::TableScan { .. } => Ok(self), + LogicalPlan::EmptyRelation { .. } => Ok(self), + LogicalPlan::Statement { .. } => Ok(self), + LogicalPlan::DescribeTable { .. } => Ok(self), + LogicalPlan::Unnest( + Unnest { + input, + exec_columns, + options, + .. + }, + _, + ) => { // Update schema with unnested column type. unnest_with_options(Arc::unwrap_or_clone(input), exec_columns, options) } @@ -771,35 +993,41 @@ impl LogicalPlan { match self { // Since expr may be different than the previous expr, schema of the projection // may change. We need to use try_new method instead of try_new_with_schema method. - LogicalPlan::Projection(Projection { .. }) => { + LogicalPlan::Projection(Projection { .. }, _) => { let input = self.only_input(inputs)?; - Projection::try_new(expr, Arc::new(input)).map(LogicalPlan::Projection) + Projection::try_new(expr, Arc::new(input)).map(LogicalPlan::projection) } - LogicalPlan::Dml(DmlStatement { - table_name, - table_schema, - op, - .. - }) => { + LogicalPlan::Dml( + DmlStatement { + table_name, + table_schema, + op, + .. + }, + _, + ) => { self.assert_no_expressions(expr)?; let input = self.only_input(inputs)?; - Ok(LogicalPlan::Dml(DmlStatement::new( + Ok(LogicalPlan::dml(DmlStatement::new( table_name.clone(), Arc::clone(table_schema), op.clone(), Arc::new(input), ))) } - LogicalPlan::Copy(CopyTo { - input: _, - output_url, - file_type, - options, - partition_by, - }) => { + LogicalPlan::Copy( + CopyTo { + input: _, + output_url, + file_type, + options, + partition_by, + }, + _, + ) => { self.assert_no_expressions(expr)?; let input = self.only_input(inputs)?; - Ok(LogicalPlan::Copy(CopyTo { + Ok(LogicalPlan::copy(CopyTo { input: Arc::new(input), output_url: output_url.clone(), file_type: Arc::clone(file_type), @@ -807,9 +1035,9 @@ impl LogicalPlan { partition_by: partition_by.clone(), })) } - LogicalPlan::Values(Values { schema, .. }) => { + LogicalPlan::Values(Values { schema, .. }, _) => { self.assert_no_inputs(inputs)?; - Ok(LogicalPlan::Values(Values { + Ok(LogicalPlan::values(Values { schema: Arc::clone(schema), values: expr .chunks_exact(schema.fields().len()) @@ -821,55 +1049,61 @@ impl LogicalPlan { let predicate = self.only_expr(expr)?; let input = self.only_input(inputs)?; - Filter::try_new(predicate, Arc::new(input)).map(LogicalPlan::Filter) + Filter::try_new(predicate, Arc::new(input)).map(LogicalPlan::filter) } - LogicalPlan::Repartition(Repartition { - partitioning_scheme, - .. - }) => match partitioning_scheme { + LogicalPlan::Repartition( + Repartition { + partitioning_scheme, + .. + }, + _, + ) => match partitioning_scheme { Partitioning::RoundRobinBatch(n) => { self.assert_no_expressions(expr)?; let input = self.only_input(inputs)?; - Ok(LogicalPlan::Repartition(Repartition { + Ok(LogicalPlan::repartition(Repartition { partitioning_scheme: Partitioning::RoundRobinBatch(*n), input: Arc::new(input), })) } Partitioning::Hash(_, n) => { let input = self.only_input(inputs)?; - Ok(LogicalPlan::Repartition(Repartition { + Ok(LogicalPlan::repartition(Repartition { partitioning_scheme: Partitioning::Hash(expr, *n), input: Arc::new(input), })) } Partitioning::DistributeBy(_) => { let input = self.only_input(inputs)?; - Ok(LogicalPlan::Repartition(Repartition { + Ok(LogicalPlan::repartition(Repartition { partitioning_scheme: Partitioning::DistributeBy(expr), input: Arc::new(input), })) } }, - LogicalPlan::Window(Window { window_expr, .. }) => { + LogicalPlan::Window(Window { window_expr, .. }, _) => { assert_eq!(window_expr.len(), expr.len()); let input = self.only_input(inputs)?; - Window::try_new(expr, Arc::new(input)).map(LogicalPlan::Window) + Window::try_new(expr, Arc::new(input)).map(LogicalPlan::window) } - LogicalPlan::Aggregate(Aggregate { group_expr, .. }) => { + LogicalPlan::Aggregate(Aggregate { group_expr, .. }, _) => { let input = self.only_input(inputs)?; // group exprs are the first expressions let agg_expr = expr.split_off(group_expr.len()); Aggregate::try_new(Arc::new(input), expr, agg_expr) - .map(LogicalPlan::Aggregate) + .map(LogicalPlan::aggregate) } - LogicalPlan::Sort(Sort { - expr: sort_expr, - fetch, - .. - }) => { + LogicalPlan::Sort( + Sort { + expr: sort_expr, + fetch, + .. + }, + _, + ) => { let input = self.only_input(inputs)?; - Ok(LogicalPlan::Sort(Sort { + Ok(LogicalPlan::sort(Sort { expr: expr .into_iter() .zip(sort_expr.iter()) @@ -879,13 +1113,16 @@ impl LogicalPlan { fetch: *fetch, })) } - LogicalPlan::Join(Join { - join_type, - join_constraint, - on, - null_equals_null, - .. - }) => { + LogicalPlan::Join( + Join { + join_type, + join_constraint, + on, + null_equals_null, + .. + }, + _, + ) => { let (left, right) = self.only_two_inputs(inputs)?; let schema = build_join_schema(left.schema(), right.schema(), join_type)?; @@ -906,7 +1143,7 @@ impl LogicalPlan { let new_on = expr.into_iter().map(|equi_expr| { // SimplifyExpression rule may add alias to the equi_expr. let unalias_expr = equi_expr.clone().unalias(); - if let Expr::BinaryExpr(BinaryExpr { left, op: Operator::Eq, right }) = unalias_expr { + if let Expr::BinaryExpr(BinaryExpr { left, op: Operator::Eq, right }, _) = unalias_expr { Ok((*left, *right)) } else { internal_err!( @@ -915,7 +1152,7 @@ impl LogicalPlan { } }).collect::>>()?; - Ok(LogicalPlan::Join(Join { + Ok(LogicalPlan::join(Join { left: Arc::new(left), right: Arc::new(right), join_type: *join_type, @@ -926,24 +1163,27 @@ impl LogicalPlan { null_equals_null: *null_equals_null, })) } - LogicalPlan::Subquery(Subquery { - outer_ref_columns, .. - }) => { + LogicalPlan::Subquery( + Subquery { + outer_ref_columns, .. + }, + _, + ) => { self.assert_no_expressions(expr)?; let input = self.only_input(inputs)?; let subquery = LogicalPlanBuilder::from(input).build()?; - Ok(LogicalPlan::Subquery(Subquery { + Ok(LogicalPlan::subquery(Subquery { subquery: Arc::new(subquery), outer_ref_columns: outer_ref_columns.clone(), })) } - LogicalPlan::SubqueryAlias(SubqueryAlias { alias, .. }) => { + LogicalPlan::SubqueryAlias(SubqueryAlias { alias, .. }, _) => { self.assert_no_expressions(expr)?; let input = self.only_input(inputs)?; SubqueryAlias::try_new(Arc::new(input), alias.clone()) - .map(LogicalPlan::SubqueryAlias) + .map(LogicalPlan::subquery_alias) } - LogicalPlan::Limit(Limit { skip, fetch, .. }) => { + LogicalPlan::Limit(Limit { skip, fetch, .. }, _) => { let old_expr_len = skip.iter().chain(fetch.iter()).count(); if old_expr_len != expr.len() { return internal_err!( @@ -955,23 +1195,26 @@ impl LogicalPlan { let new_skip = skip.as_ref().and_then(|_| expr.pop()); let new_fetch = fetch.as_ref().and_then(|_| expr.pop()); let input = self.only_input(inputs)?; - Ok(LogicalPlan::Limit(Limit { + Ok(LogicalPlan::limit(Limit { skip: new_skip.map(Box::new), fetch: new_fetch.map(Box::new), input: Arc::new(input), })) } - LogicalPlan::Ddl(DdlStatement::CreateMemoryTable(CreateMemoryTable { - name, - if_not_exists, - or_replace, - column_defaults, - temporary, - .. - })) => { + LogicalPlan::Ddl( + DdlStatement::CreateMemoryTable(CreateMemoryTable { + name, + if_not_exists, + or_replace, + column_defaults, + temporary, + .. + }), + _, + ) => { self.assert_no_expressions(expr)?; let input = self.only_input(inputs)?; - Ok(LogicalPlan::Ddl(DdlStatement::CreateMemoryTable( + Ok(LogicalPlan::ddl(DdlStatement::CreateMemoryTable( CreateMemoryTable { input: Arc::new(input), constraints: Constraints::empty(), @@ -983,16 +1226,19 @@ impl LogicalPlan { }, ))) } - LogicalPlan::Ddl(DdlStatement::CreateView(CreateView { - name, - or_replace, - definition, - temporary, - .. - })) => { + LogicalPlan::Ddl( + DdlStatement::CreateView(CreateView { + name, + or_replace, + definition, + temporary, + .. + }), + _, + ) => { self.assert_no_expressions(expr)?; let input = self.only_input(inputs)?; - Ok(LogicalPlan::Ddl(DdlStatement::CreateView(CreateView { + Ok(LogicalPlan::ddl(DdlStatement::CreateView(CreateView { input: Arc::new(input), name: name.clone(), or_replace: *or_replace, @@ -1000,10 +1246,10 @@ impl LogicalPlan { definition: definition.clone(), }))) } - LogicalPlan::Extension(e) => Ok(LogicalPlan::Extension(Extension { + LogicalPlan::Extension(e, _) => Ok(LogicalPlan::extension(Extension { node: e.node.with_exprs_and_inputs(expr, inputs)?, })), - LogicalPlan::Union(Union { schema, .. }) => { + LogicalPlan::Union(Union { schema, .. }, _) => { self.assert_no_expressions(expr)?; let input_schema = inputs[0].schema(); // If inputs are not pruned do not change schema. @@ -1012,12 +1258,12 @@ impl LogicalPlan { } else { Arc::clone(input_schema) }; - Ok(LogicalPlan::Union(Union { + Ok(LogicalPlan::union(Union { inputs: inputs.into_iter().map(Arc::new).collect(), schema, })) } - LogicalPlan::Distinct(distinct) => { + LogicalPlan::Distinct(distinct, _) => { let distinct = match distinct { Distinct::All(_) => { self.assert_no_expressions(expr)?; @@ -1041,33 +1287,36 @@ impl LogicalPlan { )?) } }; - Ok(LogicalPlan::Distinct(distinct)) + Ok(LogicalPlan::distinct(distinct)) } - LogicalPlan::RecursiveQuery(RecursiveQuery { - name, is_distinct, .. - }) => { + LogicalPlan::RecursiveQuery( + RecursiveQuery { + name, is_distinct, .. + }, + _, + ) => { self.assert_no_expressions(expr)?; let (static_term, recursive_term) = self.only_two_inputs(inputs)?; - Ok(LogicalPlan::RecursiveQuery(RecursiveQuery { + Ok(LogicalPlan::recursive_query(RecursiveQuery { name: name.clone(), static_term: Arc::new(static_term), recursive_term: Arc::new(recursive_term), is_distinct: *is_distinct, })) } - LogicalPlan::Analyze(a) => { + LogicalPlan::Analyze(a, _) => { self.assert_no_expressions(expr)?; let input = self.only_input(inputs)?; - Ok(LogicalPlan::Analyze(Analyze { + Ok(LogicalPlan::analyze(Analyze { verbose: a.verbose, schema: Arc::clone(&a.schema), input: Arc::new(input), })) } - LogicalPlan::Explain(e) => { + LogicalPlan::Explain(e, _) => { self.assert_no_expressions(expr)?; let input = self.only_input(inputs)?; - Ok(LogicalPlan::Explain(Explain { + Ok(LogicalPlan::explain(Explain { verbose: e.verbose, plan: Arc::new(input), stringified_plans: e.stringified_plans.clone(), @@ -1075,47 +1324,51 @@ impl LogicalPlan { logical_optimization_succeeded: e.logical_optimization_succeeded, })) } - LogicalPlan::Statement(Statement::Prepare(Prepare { - name, - data_types, - .. - })) => { + LogicalPlan::Statement( + Statement::Prepare(Prepare { + name, data_types, .. + }), + _, + ) => { self.assert_no_expressions(expr)?; let input = self.only_input(inputs)?; - Ok(LogicalPlan::Statement(Statement::Prepare(Prepare { + Ok(LogicalPlan::statement(Statement::Prepare(Prepare { name: name.clone(), data_types: data_types.clone(), input: Arc::new(input), }))) } - LogicalPlan::Statement(Statement::Execute(Execute { name, .. })) => { + LogicalPlan::Statement(Statement::Execute(Execute { name, .. }), _) => { self.assert_no_inputs(inputs)?; - Ok(LogicalPlan::Statement(Statement::Execute(Execute { + Ok(LogicalPlan::statement(Statement::Execute(Execute { name: name.clone(), parameters: expr, }))) } - LogicalPlan::TableScan(ts) => { + LogicalPlan::TableScan(ts, _) => { self.assert_no_inputs(inputs)?; - Ok(LogicalPlan::TableScan(TableScan { + Ok(LogicalPlan::table_scan(TableScan { filters: expr, ..ts.clone() })) } - LogicalPlan::EmptyRelation(_) - | LogicalPlan::Ddl(_) - | LogicalPlan::Statement(_) - | LogicalPlan::DescribeTable(_) => { + LogicalPlan::EmptyRelation { .. } + | LogicalPlan::Ddl { .. } + | LogicalPlan::Statement { .. } + | LogicalPlan::DescribeTable { .. } => { // All of these plan types have no inputs / exprs so should not be called self.assert_no_expressions(expr)?; self.assert_no_inputs(inputs)?; Ok(self.clone()) } - LogicalPlan::Unnest(Unnest { - exec_columns: columns, - options, - .. - }) => { + LogicalPlan::Unnest( + Unnest { + exec_columns: columns, + options, + .. + }, + _, + ) => { self.assert_no_expressions(expr)?; let input = self.only_input(inputs)?; // Update schema with unnested column type. @@ -1249,7 +1502,7 @@ impl LogicalPlan { // unwrap Prepare Ok( - if let LogicalPlan::Statement(Statement::Prepare(prepare_lp)) = + if let LogicalPlan::Statement(Statement::Prepare(prepare_lp), _) = plan_with_values { param_values.verify(&prepare_lp.data_types)?; @@ -1267,29 +1520,32 @@ impl LogicalPlan { /// If `Some(n)` then the plan can return at most `n` rows but may return fewer. pub fn max_rows(self: &LogicalPlan) -> Option { match self { - LogicalPlan::Projection(Projection { input, .. }) => input.max_rows(), - LogicalPlan::Filter(filter) => { + LogicalPlan::Projection(Projection { input, .. }, _) => input.max_rows(), + LogicalPlan::Filter(filter, _) => { if filter.is_scalar() { Some(1) } else { filter.input.max_rows() } } - LogicalPlan::Window(Window { input, .. }) => input.max_rows(), - LogicalPlan::Aggregate(Aggregate { - input, group_expr, .. - }) => { + LogicalPlan::Window(Window { input, .. }, _) => input.max_rows(), + LogicalPlan::Aggregate( + Aggregate { + input, group_expr, .. + }, + _, + ) => { // Empty group_expr will return Some(1) if group_expr .iter() - .all(|expr| matches!(expr, Expr::Literal(_))) + .all(|expr| matches!(expr, Expr::Literal(_, _))) { Some(1) } else { input.max_rows() } } - LogicalPlan::Sort(Sort { input, fetch, .. }) => { + LogicalPlan::Sort(Sort { input, fetch, .. }, _) => { match (fetch, input.max_rows()) { (Some(fetch_limit), Some(input_max)) => { Some(input_max.min(*fetch_limit)) @@ -1299,12 +1555,15 @@ impl LogicalPlan { (None, None) => None, } } - LogicalPlan::Join(Join { - left, - right, - join_type, - .. - }) => match join_type { + LogicalPlan::Join( + Join { + left, + right, + join_type, + .. + }, + _, + ) => match join_type { JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full => { match (left.max_rows(), right.max_rows()) { (Some(left_max), Some(right_max)) => { @@ -1324,8 +1583,8 @@ impl LogicalPlan { } JoinType::RightSemi | JoinType::RightAnti => right.max_rows(), }, - LogicalPlan::Repartition(Repartition { input, .. }) => input.max_rows(), - LogicalPlan::Union(Union { inputs, .. }) => inputs + LogicalPlan::Repartition(Repartition { input, .. }, _) => input.max_rows(), + LogicalPlan::Union(Union { inputs, .. }, _) => inputs .iter() .map(|plan| plan.max_rows()) .try_fold(0usize, |mut acc, input_max| { @@ -1336,28 +1595,31 @@ impl LogicalPlan { None } }), - LogicalPlan::TableScan(TableScan { fetch, .. }) => *fetch, - LogicalPlan::EmptyRelation(_) => Some(0), - LogicalPlan::RecursiveQuery(_) => None, - LogicalPlan::Subquery(_) => None, - LogicalPlan::SubqueryAlias(SubqueryAlias { input, .. }) => input.max_rows(), - LogicalPlan::Limit(limit) => match limit.get_fetch_type() { + LogicalPlan::TableScan(TableScan { fetch, .. }, _) => *fetch, + LogicalPlan::EmptyRelation(_, _) => Some(0), + LogicalPlan::RecursiveQuery(_, _) => None, + LogicalPlan::Subquery(_, _) => None, + LogicalPlan::SubqueryAlias(SubqueryAlias { input, .. }, _) => { + input.max_rows() + } + LogicalPlan::Limit(limit, _) => match limit.get_fetch_type() { Ok(FetchType::Literal(s)) => s, _ => None, }, LogicalPlan::Distinct( Distinct::All(input) | Distinct::On(DistinctOn { input, .. }), + _, ) => input.max_rows(), - LogicalPlan::Values(v) => Some(v.values.len()), - LogicalPlan::Unnest(_) => None, - LogicalPlan::Ddl(_) - | LogicalPlan::Explain(_) - | LogicalPlan::Analyze(_) - | LogicalPlan::Dml(_) - | LogicalPlan::Copy(_) - | LogicalPlan::DescribeTable(_) - | LogicalPlan::Statement(_) - | LogicalPlan::Extension(_) => None, + LogicalPlan::Values(v, _) => Some(v.values.len()), + LogicalPlan::Unnest(_, _) => None, + LogicalPlan::Ddl(_, _) + | LogicalPlan::Explain(_, _) + | LogicalPlan::Analyze(_, _) + | LogicalPlan::Dml(_, _) + | LogicalPlan::Copy(_, _) + | LogicalPlan::DescribeTable(_, _) + | LogicalPlan::Statement(_, _) + | LogicalPlan::Extension(_, _) => None, } } @@ -1385,16 +1647,19 @@ impl LogicalPlan { /// See also: [`crate::utils::columnize_expr`] pub fn columnized_output_exprs(&self) -> Result> { match self { - LogicalPlan::Aggregate(aggregate) => Ok(aggregate + LogicalPlan::Aggregate(aggregate, _) => Ok(aggregate .output_expressions()? .into_iter() .zip(self.schema().columns()) .collect()), - LogicalPlan::Window(Window { - window_expr, - input, - schema, - }) => { + LogicalPlan::Window( + Window { + window_expr, + input, + schema, + }, + _, + ) => { // The input could be another Window, so the result should also include the input's. For Example: // `EXPLAIN SELECT RANK() OVER (PARTITION BY a ORDER BY b), SUM(b) OVER (PARTITION BY a) FROM t` // Its plan is: @@ -1440,9 +1705,9 @@ impl LogicalPlan { } else { let original_name = name_preserver.save(&e); let transformed_expr = e.transform_up(|e| { - if let Expr::Placeholder(Placeholder { id, .. }) = e { + if let Expr::Placeholder(Placeholder { id, .. }, _) = e { let value = param_values.get_placeholders_with_values(&id)?; - Ok(Transformed::yes(Expr::Literal(value))) + Ok(Transformed::yes(Expr::literal(value))) } else { Ok(Transformed::no(e)) } @@ -1461,7 +1726,7 @@ impl LogicalPlan { self.apply_with_subqueries(|plan| { plan.apply_expressions(|expr| { expr.apply(|expr| { - if let Expr::Placeholder(Placeholder { id, .. }) = expr { + if let Expr::Placeholder(Placeholder { id, .. }, _) = expr { param_names.insert(id.clone()); } Ok(TreeNodeRecursion::Continue) @@ -1480,7 +1745,7 @@ impl LogicalPlan { self.apply_with_subqueries(|plan| { plan.apply_expressions(|expr| { expr.apply(|expr| { - if let Expr::Placeholder(Placeholder { id, data_type }) = expr { + if let Expr::Placeholder(Placeholder { id, data_type }, _) = expr { let prev = param_types.get(id); match (prev, data_type) { (Some(Some(prev)), Some(dt)) => { @@ -1699,13 +1964,13 @@ impl LogicalPlan { impl<'a> Display for Wrapper<'a> { fn fmt(&self, f: &mut Formatter) -> fmt::Result { match self.0 { - LogicalPlan::EmptyRelation(_) => write!(f, "EmptyRelation"), + LogicalPlan::EmptyRelation(_, _) => write!(f, "EmptyRelation"), LogicalPlan::RecursiveQuery(RecursiveQuery { is_distinct, .. - }) => { + }, _) => { write!(f, "RecursiveQuery: is_distinct={}", is_distinct) } - LogicalPlan::Values(Values { ref values, .. }) => { + LogicalPlan::Values(Values { ref values, .. }, _) => { let str_values: Vec<_> = values .iter() // limit to only 5 values to avoid horrible display @@ -1731,7 +1996,7 @@ impl LogicalPlan { ref filters, ref fetch, .. - }) => { + }, _) => { let projected_fields = match projection { Some(indices) => { let schema = source.schema(); @@ -1799,7 +2064,7 @@ impl LogicalPlan { Ok(()) } - LogicalPlan::Projection(Projection { ref expr, .. }) => { + LogicalPlan::Projection(Projection { ref expr, .. }, _) => { write!(f, "Projection: ")?; for (i, expr_item) in expr.iter().enumerate() { if i > 0 { @@ -1809,7 +2074,7 @@ impl LogicalPlan { } Ok(()) } - LogicalPlan::Dml(DmlStatement { table_name, op, .. }) => { + LogicalPlan::Dml(DmlStatement { table_name, op, .. }, _) => { write!(f, "Dml: op=[{op}] table=[{table_name}]") } LogicalPlan::Copy(CopyTo { @@ -1818,7 +2083,7 @@ impl LogicalPlan { file_type, options, .. - }) => { + }, _) => { let op_str = options .iter() .map(|(k, v)| format!("{k} {v}")) @@ -1827,16 +2092,16 @@ impl LogicalPlan { write!(f, "CopyTo: format={} output_url={output_url} options: ({op_str})", file_type.get_ext()) } - LogicalPlan::Ddl(ddl) => { + LogicalPlan::Ddl(ddl, _) => { write!(f, "{}", ddl.display()) } LogicalPlan::Filter(Filter { predicate: ref expr, .. - }) => write!(f, "Filter: {expr}"), + }, _) => write!(f, "Filter: {expr}"), LogicalPlan::Window(Window { ref window_expr, .. - }) => { + }, _) => { write!( f, "WindowAggr: windowExpr=[[{}]]", @@ -1847,13 +2112,13 @@ impl LogicalPlan { ref group_expr, ref aggr_expr, .. - }) => write!( + }, _) => write!( f, "Aggregate: groupBy=[[{}]], aggr=[[{}]]", expr_vec_fmt!(group_expr), expr_vec_fmt!(aggr_expr) ), - LogicalPlan::Sort(Sort { expr, fetch, .. }) => { + LogicalPlan::Sort(Sort { expr, fetch, .. }, _) => { write!(f, "Sort: ")?; for (i, expr_item) in expr.iter().enumerate() { if i > 0 { @@ -1873,7 +2138,7 @@ impl LogicalPlan { join_constraint, join_type, .. - }) => { + }, _) => { let join_expr: Vec = keys.iter().map(|(l, r)| format!("{l} = {r}")).collect(); let filter_expr = filter @@ -1909,7 +2174,7 @@ impl LogicalPlan { LogicalPlan::Repartition(Repartition { partitioning_scheme, .. - }) => match partitioning_scheme { + }, _) => match partitioning_scheme { Partitioning::RoundRobinBatch(n) => { write!(f, "Repartition: RoundRobinBatch partition_count={n}") } @@ -1933,7 +2198,7 @@ impl LogicalPlan { ) } }, - LogicalPlan::Limit(limit) => { + LogicalPlan::Limit(limit, _) => { // Attempt to display `skip` and `fetch` as literals if possible, otherwise as expressions. let skip_str = match limit.get_skip_type() { Ok(SkipType::Literal(n)) => n.to_string(), @@ -1949,16 +2214,16 @@ impl LogicalPlan { "Limit: skip={}, fetch={}", skip_str,fetch_str, ) } - LogicalPlan::Subquery(Subquery { .. }) => { + LogicalPlan::Subquery(Subquery { .. }, _) => { write!(f, "Subquery:") } - LogicalPlan::SubqueryAlias(SubqueryAlias { ref alias, .. }) => { + LogicalPlan::SubqueryAlias(SubqueryAlias { ref alias, .. }, _) => { write!(f, "SubqueryAlias: {alias}") } - LogicalPlan::Statement(statement) => { + LogicalPlan::Statement(statement, _) => { write!(f, "{}", statement.display()) } - LogicalPlan::Distinct(distinct) => match distinct { + LogicalPlan::Distinct(distinct, _) => match distinct { Distinct::All(_) => write!(f, "Distinct:"), Distinct::On(DistinctOn { on_expr, @@ -1975,15 +2240,15 @@ impl LogicalPlan { }, LogicalPlan::Explain { .. } => write!(f, "Explain"), LogicalPlan::Analyze { .. } => write!(f, "Analyze"), - LogicalPlan::Union(_) => write!(f, "Union"), - LogicalPlan::Extension(e) => e.node.fmt_for_explain(f), - LogicalPlan::DescribeTable(DescribeTable { .. }) => { + LogicalPlan::Union(_, _) => write!(f, "Union"), + LogicalPlan::Extension(e, _) => e.node.fmt_for_explain(f), + LogicalPlan::DescribeTable(DescribeTable { .. }, _) => { write!(f, "DescribeTable") } LogicalPlan::Unnest(Unnest { input: plan, list_type_columns: list_col_indices, - struct_type_columns: struct_col_indices, .. }) => { + struct_type_columns: struct_col_indices, .. }, _) => { let input_columns = plan.schema().columns(); let list_type_columns = list_col_indices .iter() @@ -2005,6 +2270,131 @@ impl LogicalPlan { } Wrapper(self) } + + pub fn projection(projection: Projection) -> Self { + let stats = projection.stats(); + LogicalPlan::Projection(projection, stats) + } + + pub fn filter(filter: Filter) -> Self { + let stats = filter.stats(); + LogicalPlan::Filter(filter, stats) + } + + pub fn statement(statement: Statement) -> Self { + let stats = statement.stats(); + LogicalPlan::Statement(statement, stats) + } + + pub fn window(window: Window) -> Self { + let stats = window.stats(); + LogicalPlan::Window(window, stats) + } + + pub fn aggregate(aggregate: Aggregate) -> Self { + let stats = aggregate.stats(); + LogicalPlan::Aggregate(aggregate, stats) + } + + pub fn sort(sort: Sort) -> Self { + let stats = sort.stats(); + LogicalPlan::Sort(sort, stats) + } + + pub fn join(join: Join) -> Self { + let stats = join.stats(); + LogicalPlan::Join(join, stats) + } + + pub fn repartition(repartition: Repartition) -> Self { + let stats = repartition.stats(); + LogicalPlan::Repartition(repartition, stats) + } + + pub fn union(projection: Union) -> Self { + let stats = projection.stats(); + LogicalPlan::Union(projection, stats) + } + + pub fn table_scan(table_scan: TableScan) -> Self { + let stats = table_scan.stats(); + LogicalPlan::TableScan(table_scan, stats) + } + + pub fn subquery(subquery: Subquery) -> Self { + let stats = subquery.stats(); + LogicalPlan::Subquery(subquery, stats) + } + + pub fn subquery_alias(subquery_alias: SubqueryAlias) -> Self { + let stats = subquery_alias.stats(); + LogicalPlan::SubqueryAlias(subquery_alias, stats) + } + + pub fn limit(limit: Limit) -> Self { + let stats = limit.stats(); + LogicalPlan::Limit(limit, stats) + } + + pub fn values(values: Values) -> Self { + let stats = values.stats(); + LogicalPlan::Values(values, stats) + } + + pub fn explain(explain: Explain) -> Self { + let stats = explain.stats(); + LogicalPlan::Explain(explain, stats) + } + + pub fn analyze(analyze: Analyze) -> Self { + let stats = analyze.stats(); + LogicalPlan::Analyze(analyze, stats) + } + + pub fn extension(extension: Extension) -> Self { + let stats = extension.stats(); + LogicalPlan::Extension(extension, stats) + } + + pub fn distinct(distinct: Distinct) -> Self { + let stats = distinct.stats(); + LogicalPlan::Distinct(distinct, stats) + } + + pub fn dml(dml_statement: DmlStatement) -> Self { + let stats = dml_statement.stats(); + LogicalPlan::Dml(dml_statement, stats) + } + + pub fn ddl(ddl_statement: DdlStatement) -> Self { + let stats = ddl_statement.stats(); + LogicalPlan::Ddl(ddl_statement, stats) + } + + pub fn copy(copy_to: CopyTo) -> Self { + let stats = copy_to.stats(); + LogicalPlan::Copy(copy_to, stats) + } + + pub fn describe_table(describe_table: DescribeTable) -> Self { + let stats = LogicalPlanStats::empty(); + LogicalPlan::DescribeTable(describe_table, stats) + } + + pub fn unnest(unnest: Unnest) -> Self { + let stats = unnest.stats(); + LogicalPlan::Unnest(unnest, stats) + } + + pub fn recursive_query(recursive_query: RecursiveQuery) -> Self { + let stats = recursive_query.stats(); + LogicalPlan::RecursiveQuery(recursive_query, stats) + } + + pub fn empty_relation(empty_relation: EmptyRelation) -> Self { + let stats = LogicalPlanStats::empty(); + LogicalPlan::EmptyRelation(empty_relation, stats) + } } impl Display for LogicalPlan { @@ -2071,6 +2461,12 @@ pub struct RecursiveQuery { pub is_distinct: bool, } +impl RecursiveQuery { + pub(crate) fn stats(&self) -> LogicalPlanStats { + self.static_term.stats().merge(self.recursive_term.stats()) + } +} + /// Values expression. See /// [Postgres VALUES](https://www.postgresql.org/docs/current/queries-values.html) /// documentation for more details. @@ -2082,6 +2478,15 @@ pub struct Values { pub values: Vec>, } +impl Values { + pub(crate) fn stats(&self) -> LogicalPlanStats { + self.values + .iter() + .flatten() + .fold(LogicalPlanStats::empty(), |s, e| s.merge(e.stats())) + } +} + // Manual implementation needed because of `schema` field. Comparison excludes this field. impl PartialOrd for Values { fn partial_cmp(&self, other: &Self) -> Option { @@ -2140,13 +2545,20 @@ impl Projection { /// Create a new Projection using the specified output schema pub fn new_from_schema(input: Arc, schema: DFSchemaRef) -> Self { - let expr: Vec = schema.columns().into_iter().map(Expr::Column).collect(); + let expr: Vec = schema.columns().into_iter().map(Expr::column).collect(); Self { expr, input, schema, } } + + pub(crate) fn stats(&self) -> LogicalPlanStats { + self.expr.iter().fold( + LogicalPlanStats::empty().merge(self.input.stats()), + |s, e| s.merge(e.stats()), + ) + } } /// Computes the schema of the result produced by applying a projection to the input logical plan. @@ -2210,6 +2622,10 @@ impl SubqueryAlias { schema, }) } + + pub(crate) fn stats(&self) -> LogicalPlanStats { + self.input.stats() + } } // Manual implementation needed because of `schema` field. Comparison excludes this field. @@ -2328,11 +2744,14 @@ impl Filter { let eq_pred_cols: HashSet<_> = exprs .iter() .filter_map(|expr| { - let Expr::BinaryExpr(BinaryExpr { - left, - op: Operator::Eq, - right, - }) = expr + let Expr::BinaryExpr( + BinaryExpr { + left, + op: Operator::Eq, + right, + }, + _, + ) = expr else { return None; }; @@ -2342,8 +2761,8 @@ impl Filter { } match (left.as_ref(), right.as_ref()) { - (Expr::Column(_), Expr::Column(_)) => None, - (Expr::Column(c), _) | (_, Expr::Column(c)) => { + (Expr::Column(_, _), Expr::Column(_, _)) => None, + (Expr::Column(c, _), _) | (_, Expr::Column(c, _)) => { Some(schema.index_of_column(c).unwrap()) } _ => None, @@ -2360,6 +2779,10 @@ impl Filter { } false } + + pub(crate) fn stats(&self) -> LogicalPlanStats { + self.input.stats().merge(self.predicate.stats()) + } } /// Window its input based on a set of window spec and window function (e.g. SUM or RANK) @@ -2399,11 +2822,14 @@ impl Window { .iter() .enumerate() .filter_map(|(idx, expr)| { - if let Expr::WindowFunction(WindowFunction { - fun: WindowFunctionDefinition::WindowUDF(udwf), - partition_by, - .. - }) = expr + if let Expr::WindowFunction( + WindowFunction { + fun: WindowFunctionDefinition::WindowUDF(udwf), + partition_by, + .. + }, + _, + ) = expr { // When there is no PARTITION BY, row number will be unique // across the entire table. @@ -2457,6 +2883,12 @@ impl Window { schema, }) } + + pub(crate) fn stats(&self) -> LogicalPlanStats { + self.window_expr + .iter() + .fold(self.input.stats(), |s, e| s.merge(e.stats())) + } } // Manual implementation needed because of `schema` field. Comparison excludes this field. @@ -2604,6 +3036,12 @@ impl TableScan { fetch, }) } + + pub(crate) fn stats(&self) -> LogicalPlanStats { + self.filters + .iter() + .fold(LogicalPlanStats::empty(), |s, e| s.merge(e.stats())) + } } // Repartition the plan based on a partitioning scheme. @@ -2615,6 +3053,18 @@ pub struct Repartition { pub partitioning_scheme: Partitioning, } +impl Repartition { + pub(crate) fn stats(&self) -> LogicalPlanStats { + let s = self.input.stats(); + match &self.partitioning_scheme { + Partitioning::Hash(expr, _) | Partitioning::DistributeBy(expr) => { + expr.iter().fold(s, |s, e| s.merge(e.stats())) + } + Partitioning::RoundRobinBatch(_) => s, + } + } +} + /// Union multiple inputs #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct Union { @@ -2630,6 +3080,13 @@ impl PartialOrd for Union { self.inputs.partial_cmp(&other.inputs) } } +impl Union { + pub(crate) fn stats(&self) -> LogicalPlanStats { + self.inputs + .iter() + .fold(LogicalPlanStats::empty(), |s, e| s.merge(e.stats())) + } +} /// Describe the schema of table /// @@ -2716,6 +3173,12 @@ impl PartialOrd for Explain { } } +impl Explain { + pub(crate) fn stats(&self) -> LogicalPlanStats { + self.plan.stats() + } +} + /// Runs the actual plan, and then prints the physical plan with /// with execution metrics. #[derive(Debug, Clone, PartialEq, Eq, Hash)] @@ -2738,6 +3201,12 @@ impl PartialOrd for Analyze { } } +impl Analyze { + pub(crate) fn stats(&self) -> LogicalPlanStats { + self.input.stats() + } +} + /// Extension operator defined outside of DataFusion // TODO(clippy): This clippy `allow` should be removed if // the manual `PartialEq` is removed in favor of a derive. @@ -2764,6 +3233,18 @@ impl PartialOrd for Extension { } } +impl Extension { + pub(crate) fn stats(&self) -> LogicalPlanStats { + self.node.inputs().iter().fold( + self.node + .expressions() + .iter() + .fold(LogicalPlanStats::empty(), |s, e| s.merge(e.stats())), + |s, e| s.merge(e.stats()), + ) + } +} + /// Produces the first `n` tuples from its input and discards the rest. #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] pub struct Limit { @@ -2798,7 +3279,7 @@ impl Limit { pub fn get_skip_type(&self) -> Result { match self.skip.as_deref() { Some(expr) => match *expr { - Expr::Literal(ScalarValue::Int64(s)) => { + Expr::Literal(ScalarValue::Int64(s), _) => { // `skip = NULL` is equivalent to `skip = 0` let s = s.unwrap_or(0); if s >= 0 { @@ -2818,19 +3299,28 @@ impl Limit { pub fn get_fetch_type(&self) -> Result { match self.fetch.as_deref() { Some(expr) => match *expr { - Expr::Literal(ScalarValue::Int64(Some(s))) => { + Expr::Literal(ScalarValue::Int64(Some(s)), _) => { if s >= 0 { Ok(FetchType::Literal(Some(s as usize))) } else { plan_err!("LIMIT must be >= 0, '{}' was provided", s) } } - Expr::Literal(ScalarValue::Int64(None)) => Ok(FetchType::Literal(None)), + Expr::Literal(ScalarValue::Int64(None), _) => { + Ok(FetchType::Literal(None)) + } _ => Ok(FetchType::UnsupportedExpr), }, None => Ok(FetchType::Literal(None)), } } + + pub(crate) fn stats(&self) -> LogicalPlanStats { + self.skip + .iter() + .chain(self.fetch.iter()) + .fold(self.input.stats(), |s, e| s.merge(e.stats())) + } } /// Removes duplicate rows from the input @@ -2850,6 +3340,13 @@ impl Distinct { Distinct::On(DistinctOn { input, .. }) => input, } } + + pub(crate) fn stats(&self) -> LogicalPlanStats { + match self { + Distinct::All(input) => input.stats(), + Distinct::On(distinct_on) => distinct_on.stats(), + } + } } /// Removes duplicate rows from the input @@ -2930,6 +3427,14 @@ impl DistinctOn { self.sort_expr = Some(sort_expr); Ok(self) } + + pub(crate) fn stats(&self) -> LogicalPlanStats { + self.on_expr + .iter() + .chain(self.select_expr.iter()) + .chain(self.sort_expr.iter().flatten().map(|s| &s.expr)) + .fold(self.input.stats(), |s, e| s.merge(e.stats())) + } } // Manual implementation needed because of `schema` field. Comparison excludes this field. @@ -2989,7 +3494,7 @@ impl Aggregate { ) -> Result { let group_expr = enumerate_grouping_sets(group_expr)?; - let is_grouping_set = matches!(group_expr.as_slice(), [Expr::GroupingSet(_)]); + let is_grouping_set = matches!(group_expr.as_slice(), [Expr::GroupingSet(_, _)]); let grouping_expr: Vec<&Expr> = grouping_set_to_exprlist(group_expr.as_slice())?; @@ -3062,7 +3567,7 @@ impl Aggregate { } fn is_grouping_set(&self) -> bool { - matches!(self.group_expr.as_slice(), [Expr::GroupingSet(_)]) + matches!(self.group_expr.as_slice(), [Expr::GroupingSet(_, _)]) } /// Get the output expressions. @@ -3071,7 +3576,7 @@ impl Aggregate { let mut exprs = grouping_set_to_exprlist(self.group_expr.as_slice())?; if self.is_grouping_set() { exprs.push(INTERNAL_ID_EXPR.get_or_init(|| { - Expr::Column(Column::from_name(Self::INTERNAL_GROUPING_ID)) + Expr::column(Column::from_name(Self::INTERNAL_GROUPING_ID)) })); } exprs.extend(self.aggr_expr.iter()); @@ -3120,6 +3625,13 @@ impl Aggregate { /// with `NULL` values. To handle these cases correctly, we must distinguish /// between an actual `NULL` value in a column and a column being excluded from the set. pub const INTERNAL_GROUPING_ID: &'static str = "__grouping_id"; + + pub(crate) fn stats(&self) -> LogicalPlanStats { + self.group_expr + .iter() + .chain(self.aggr_expr.iter()) + .fold(self.input.stats(), |s, e| s.merge(e.stats())) + } } // Manual implementation needed because of `schema` field. Comparison excludes this field. @@ -3141,7 +3653,7 @@ impl PartialOrd for Aggregate { fn contains_grouping_set(group_expr: &[Expr]) -> bool { group_expr .iter() - .any(|expr| matches!(expr, Expr::GroupingSet(_))) + .any(|expr| matches!(expr, Expr::GroupingSet(_, _))) } /// Calculates functional dependencies for aggregate expressions. @@ -3187,9 +3699,9 @@ fn calc_func_dependencies_for_project( let proj_indices = exprs .iter() .map(|expr| match expr { - Expr::Wildcard(Wildcard { qualifier, options }) => { + Expr::Wildcard(Wildcard { qualifier, options }, _) => { let wildcard_fields = exprlist_to_fields( - vec![&Expr::Wildcard(Wildcard { + vec![&Expr::wildcard(Wildcard { qualifier: qualifier.clone(), options: options.clone(), })], @@ -3207,7 +3719,7 @@ fn calc_func_dependencies_for_project( .collect::>(), ) } - Expr::Alias(alias) => { + Expr::Alias(alias, _) => { let name = format!("{}", alias.expr); Ok(input_fields .iter() @@ -3247,6 +3759,15 @@ pub struct Sort { pub fetch: Option, } +impl Sort { + pub(crate) fn stats(&self) -> LogicalPlanStats { + self.expr + .iter() + .map(|s| &s.expr) + .fold(self.input.stats(), |s, e| s.merge(e.stats())) + } +} + /// Join two logical plans on one or more join columns #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct Join { @@ -3277,7 +3798,7 @@ impl Join { column_on: (Vec, Vec), ) -> Result { let original_join = match original { - LogicalPlan::Join(join) => join, + LogicalPlan::Join(join, _) => join, _ => return plan_err!("Could not create join with project input"), }; @@ -3285,7 +3806,7 @@ impl Join { .0 .into_iter() .zip(column_on.1) - .map(|(l, r)| (Expr::Column(l), Expr::Column(r))) + .map(|(l, r)| (Expr::column(l), Expr::column(r))) .collect(); let join_schema = build_join_schema(left.schema(), right.schema(), &original_join.join_type)?; @@ -3301,6 +3822,16 @@ impl Join { null_equals_null: original_join.null_equals_null, }) } + + pub(crate) fn stats(&self) -> LogicalPlanStats { + self.on + .iter() + .flat_map(|(l, r)| [l, r]) + .chain(&self.filter) + .fold(self.left.stats().merge(self.right.stats()), |s, e| { + s.merge(e.stats()) + }) + } } // Manual implementation needed because of `schema` field. Comparison excludes this field. @@ -3357,8 +3888,8 @@ pub struct Subquery { impl Subquery { pub fn try_from_expr(plan: &Expr) -> Result<&Subquery> { match plan { - Expr::ScalarSubquery(it) => Ok(it), - Expr::Cast(cast) => Subquery::try_from_expr(cast.expr.as_ref()), + Expr::ScalarSubquery(it, _) => Ok(it), + Expr::Cast(cast, _) => Subquery::try_from_expr(cast.expr.as_ref()), _ => plan_err!("Could not coerce into ScalarSubquery!"), } } @@ -3369,6 +3900,12 @@ impl Subquery { outer_ref_columns: self.outer_ref_columns.clone(), } } + + pub(crate) fn stats(&self) -> LogicalPlanStats { + self.outer_ref_columns + .iter() + .fold(self.subquery.stats(), |s, e| s.merge(e.stats())) + } } impl Debug for Subquery { @@ -3488,6 +4025,12 @@ impl PartialOrd for Unnest { } } +impl Unnest { + pub(crate) fn stats(&self) -> LogicalPlanStats { + self.input.stats() + } +} + #[cfg(test)] mod tests { @@ -3901,7 +4444,7 @@ digraph { let empty_schema = Arc::new(DFSchema::empty()); let p = Projection::try_new_with_schema( vec![col("a")], - Arc::new(LogicalPlan::EmptyRelation(EmptyRelation { + Arc::new(LogicalPlan::empty_relation(EmptyRelation { produce_one_row: false, schema: Arc::clone(&empty_schema), })), @@ -3980,7 +4523,7 @@ digraph { let plan = table_scan(TableReference::none(), &schema, None) .unwrap() .aggregate( - vec![Expr::GroupingSet(GroupingSet::GroupingSets(vec![ + vec![Expr::grouping_set(GroupingSet::GroupingSets(vec![ vec![col("foo")], vec![col("bar")], ]))], @@ -4016,7 +4559,7 @@ digraph { ) .unwrap(), ); - let scan = Arc::new(LogicalPlan::TableScan(TableScan { + let scan = Arc::new(LogicalPlan::table_scan(TableScan { table_name: TableReference::bare("tab"), source: Arc::clone(&source) as Arc, projection: None, @@ -4027,7 +4570,7 @@ digraph { let col = schema.field_names()[0].clone(); let filter = Filter::try_new( - Expr::Column(col.into()).eq(Expr::Literal(ScalarValue::Int32(Some(1)))), + Expr::column(col.into()).eq(Expr::literal(ScalarValue::Int32(Some(1)))), scan, ) .unwrap(); @@ -4046,7 +4589,7 @@ digraph { ) .unwrap(), ); - let scan = Arc::new(LogicalPlan::TableScan(TableScan { + let scan = Arc::new(LogicalPlan::table_scan(TableScan { table_name: TableReference::bare("tab"), source, projection: None, @@ -4057,7 +4600,7 @@ digraph { let col = schema.field_names()[0].clone(); let filter = - Filter::try_new(Expr::Column(col.into()).eq(lit(1i32)), scan).unwrap(); + Filter::try_new(Expr::column(col.into()).eq(lit(1i32)), scan).unwrap(); assert!(filter.is_scalar()); } @@ -4081,13 +4624,13 @@ digraph { // the parent plan is built again with call to LogicalPlan::with_new_inputs -> with_new_exprs let plan = plan .transform(|plan| match plan { - LogicalPlan::TableScan(table) => { + LogicalPlan::TableScan(table, _) => { let filter = Filter::try_new( external_filter.clone(), - Arc::new(LogicalPlan::TableScan(table)), + Arc::new(LogicalPlan::table_scan(table)), ) .unwrap(); - Ok(Transformed::yes(LogicalPlan::Filter(filter))) + Ok(Transformed::yes(LogicalPlan::filter(filter))) } x => Ok(Transformed::no(x)), }) @@ -4103,12 +4646,12 @@ digraph { #[test] fn test_plan_partial_ord() { - let empty_relation = LogicalPlan::EmptyRelation(EmptyRelation { + let empty_relation = LogicalPlan::empty_relation(EmptyRelation { produce_one_row: false, schema: Arc::new(DFSchema::empty()), }); - let describe_table = LogicalPlan::DescribeTable(DescribeTable { + let describe_table = LogicalPlan::describe_table(DescribeTable { schema: Arc::new(Schema::new(vec![Field::new( "foo", DataType::Int32, @@ -4117,7 +4660,7 @@ digraph { output_schema: DFSchemaRef::new(DFSchema::empty()), }); - let describe_table_clone = LogicalPlan::DescribeTable(DescribeTable { + let describe_table_clone = LogicalPlan::describe_table(DescribeTable { schema: Arc::new(Schema::new(vec![Field::new( "foo", DataType::Int32, @@ -4139,12 +4682,12 @@ digraph { #[test] fn test_limit_with_new_children() { - let limit = LogicalPlan::Limit(Limit { + let limit = LogicalPlan::limit(Limit { skip: None, - fetch: Some(Box::new(Expr::Literal( + fetch: Some(Box::new(Expr::literal( ScalarValue::new_ten(&DataType::UInt32).unwrap(), ))), - input: Arc::new(LogicalPlan::Values(Values { + input: Arc::new(LogicalPlan::values(Values { schema: Arc::new(DFSchema::empty()), values: vec![vec![]], })), diff --git a/datafusion/expr/src/logical_plan/statement.rs b/datafusion/expr/src/logical_plan/statement.rs index 26df379f5e4ad..ca301b9711ccc 100644 --- a/datafusion/expr/src/logical_plan/statement.rs +++ b/datafusion/expr/src/logical_plan/statement.rs @@ -20,6 +20,7 @@ use datafusion_common::{DFSchema, DFSchemaRef}; use std::fmt::{self, Display}; use std::sync::{Arc, OnceLock}; +use crate::logical_plan::tree_node::LogicalPlanStats; use crate::{expr_vec_fmt, Expr, LogicalPlan}; /// Statements have a unchanging empty schema. @@ -130,6 +131,16 @@ impl Statement { } Wrapper(self) } + + pub(crate) fn stats(&self) -> LogicalPlanStats { + match self { + Statement::Prepare(Prepare { input, .. }) => input.stats(), + Statement::Execute(Execute { parameters, .. }) => parameters + .iter() + .fold(LogicalPlanStats::empty(), |s, e| s.merge(e.stats())), + _ => LogicalPlanStats::empty(), + } + } } /// Indicates if a transaction was committed or aborted diff --git a/datafusion/expr/src/logical_plan/tree_node.rs b/datafusion/expr/src/logical_plan/tree_node.rs index 6850c30f4f81b..306deec83ef77 100644 --- a/datafusion/expr/src/logical_plan/tree_node.rs +++ b/datafusion/expr/src/logical_plan/tree_node.rs @@ -53,6 +53,108 @@ use datafusion_common::tree_node::{ TreeNodeRewriter, TreeNodeVisitor, }; use datafusion_common::{internal_err, Result}; +use enumset::{EnumSet, EnumSetType}; + +#[derive(EnumSetType, Debug)] +pub enum LogicalPlanPattern { + /// [`Expr`] + // ExprAlias, + ExprColumn, + // ExprScalarVariable, + ExprLiteral, + ExprBinaryExpr, + ExprLike, + // ExprSimilarTo, + ExprNot, + ExprIsNotNull, + ExprIsNull, + // ExprIsTrue, + // ExprIsFalse, + // ExprIsUnknown, + // ExprIsNotTrue, + // ExprIsNotFalse, + // ExprIsNotUnknown, + ExprNegative, + // ExprGetIndexedField, + ExprBetween, + ExprCase, + ExprCast, + ExprTryCast, + ExprScalarFunction, + ExprAggregateFunction, + ExprWindowFunction, + ExprInList, + // ExprExists, + // ExprInSubquery, + // ExprScalarSubquery, + // ExprWildcard, + // ExprGroupingSet, + // ExprPlaceholder, + // ExprOuterReferenceColumn, + // ExprUnnest, + + // /// [`LogicalPlan`] + // LogicalPlanProjection, + // LogicalPlanFilter, + // LogicalPlanWindow, + // LogicalPlanAggregate, + // LogicalPlanSort, + // LogicalPlanJoin, + // LogicalPlanCrossJoin, + // LogicalPlanRepartition, + // LogicalPlanUnion, + // LogicalPlanTableScan, + // LogicalPlanEmptyRelation, + // LogicalPlanSubquery, + // LogicalPlanSubqueryAlias, + // LogicalPlanLimit, + // LogicalPlanStatement, + // LogicalPlanValues, + // LogicalPlanExplain, + // LogicalPlanAnalyze, + // LogicalPlanExtension, + // LogicalPlanDistinct, + // LogicalPlanDml, + // LogicalPlanDdl, + // LogicalPlanCopy, + // LogicalPlanDescribeTable, + // LogicalPlanUnnest, + // LogicalPlanRecursiveQuery, +} + +#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Hash, Debug)] +pub struct LogicalPlanStats { + patterns: EnumSet, +} + +impl LogicalPlanStats { + pub(crate) fn new(patterns: EnumSet) -> Self { + Self { patterns } + } + + pub(crate) fn empty() -> Self { + Self { + patterns: EnumSet::empty(), + } + } + + pub(crate) fn merge(mut self, other: LogicalPlanStats) -> Self { + self.patterns.insert_all(other.patterns); + self + } + + pub fn contains_pattern(&self, pattern: LogicalPlanPattern) -> bool { + self.patterns.contains(pattern) + } + + pub fn contains_all_patterns(&self, patterns: EnumSet) -> bool { + self.patterns.is_superset(patterns) + } + + pub fn contains_any_patterns(&self, patterns: EnumSet) -> bool { + !self.patterns.is_disjoint(patterns) + } +} impl TreeNode for LogicalPlan { fn apply_children<'n, F: FnMut(&'n Self) -> Result>( @@ -75,75 +177,93 @@ impl TreeNode for LogicalPlan { f: F, ) -> Result> { Ok(match self { - LogicalPlan::Projection(Projection { - expr, - input, - schema, - }) => input.map_elements(f)?.update_data(|input| { - LogicalPlan::Projection(Projection { + LogicalPlan::Projection( + Projection { + expr, + input, + schema, + }, + _, + ) => input.map_elements(f)?.update_data(|input| { + LogicalPlan::projection(Projection { expr, input, schema, }) }), - LogicalPlan::Filter(Filter { - predicate, - input, - having, - }) => input.map_elements(f)?.update_data(|input| { - LogicalPlan::Filter(Filter { + LogicalPlan::Filter( + Filter { + predicate, + input, + having, + }, + _, + ) => input.map_elements(f)?.update_data(|input| { + LogicalPlan::filter(Filter { predicate, input, having, }) }), - LogicalPlan::Repartition(Repartition { - input, - partitioning_scheme, - }) => input.map_elements(f)?.update_data(|input| { - LogicalPlan::Repartition(Repartition { + LogicalPlan::Repartition( + Repartition { + input, + partitioning_scheme, + }, + _, + ) => input.map_elements(f)?.update_data(|input| { + LogicalPlan::repartition(Repartition { input, partitioning_scheme, }) }), - LogicalPlan::Window(Window { - input, - window_expr, - schema, - }) => input.map_elements(f)?.update_data(|input| { - LogicalPlan::Window(Window { + LogicalPlan::Window( + Window { + input, + window_expr, + schema, + }, + _, + ) => input.map_elements(f)?.update_data(|input| { + LogicalPlan::window(Window { input, window_expr, schema, }) }), - LogicalPlan::Aggregate(Aggregate { - input, - group_expr, - aggr_expr, - schema, - }) => input.map_elements(f)?.update_data(|input| { - LogicalPlan::Aggregate(Aggregate { + LogicalPlan::Aggregate( + Aggregate { + input, + group_expr, + aggr_expr, + schema, + }, + _, + ) => input.map_elements(f)?.update_data(|input| { + LogicalPlan::aggregate(Aggregate { input, group_expr, aggr_expr, schema, }) }), - LogicalPlan::Sort(Sort { expr, input, fetch }) => input + LogicalPlan::Sort(Sort { expr, input, fetch }, _) => input .map_elements(f)? - .update_data(|input| LogicalPlan::Sort(Sort { expr, input, fetch })), - LogicalPlan::Join(Join { - left, - right, - on, - filter, - join_type, - join_constraint, - schema, - null_equals_null, - }) => (left, right).map_elements(f)?.update_data(|(left, right)| { - LogicalPlan::Join(Join { + .update_data(|input| LogicalPlan::sort(Sort { expr, input, fetch })), + LogicalPlan::Join( + Join { + left, + right, + on, + filter, + join_type, + join_constraint, + schema, + null_equals_null, + }, + _, + ) => (left, right).map_elements(f)?.update_data(|(left, right)| { + LogicalPlan::join(Join { left, right, on, @@ -154,35 +274,43 @@ impl TreeNode for LogicalPlan { null_equals_null, }) }), - LogicalPlan::Limit(Limit { skip, fetch, input }) => input + LogicalPlan::Limit(Limit { skip, fetch, input }, _) => input .map_elements(f)? - .update_data(|input| LogicalPlan::Limit(Limit { skip, fetch, input })), - LogicalPlan::Subquery(Subquery { - subquery, - outer_ref_columns, - }) => subquery.map_elements(f)?.update_data(|subquery| { - LogicalPlan::Subquery(Subquery { + .update_data(|input| LogicalPlan::limit(Limit { skip, fetch, input })), + LogicalPlan::Subquery( + Subquery { + subquery, + outer_ref_columns, + }, + _, + ) => subquery.map_elements(f)?.update_data(|subquery| { + LogicalPlan::subquery(Subquery { subquery, outer_ref_columns, }) }), - LogicalPlan::SubqueryAlias(SubqueryAlias { - input, - alias, - schema, - }) => input.map_elements(f)?.update_data(|input| { - LogicalPlan::SubqueryAlias(SubqueryAlias { + LogicalPlan::SubqueryAlias( + SubqueryAlias { + input, + alias, + schema, + }, + _, + ) => input.map_elements(f)?.update_data(|input| { + LogicalPlan::subquery_alias(SubqueryAlias { input, alias, schema, }) }), - LogicalPlan::Extension(extension) => rewrite_extension_inputs(extension, f)? - .update_data(LogicalPlan::Extension), - LogicalPlan::Union(Union { inputs, schema }) => inputs + LogicalPlan::Extension(extension, _) => { + rewrite_extension_inputs(extension, f)? + .update_data(LogicalPlan::extension) + } + LogicalPlan::Union(Union { inputs, schema }, _) => inputs .map_elements(f)? - .update_data(|inputs| LogicalPlan::Union(Union { inputs, schema })), - LogicalPlan::Distinct(distinct) => match distinct { + .update_data(|inputs| LogicalPlan::union(Union { inputs, schema })), + LogicalPlan::Distinct(distinct, _) => match distinct { Distinct::All(input) => input.map_elements(f)?.update_data(Distinct::All), Distinct::On(DistinctOn { on_expr, @@ -200,15 +328,18 @@ impl TreeNode for LogicalPlan { }) }), } - .update_data(LogicalPlan::Distinct), - LogicalPlan::Explain(Explain { - verbose, - plan, - stringified_plans, - schema, - logical_optimization_succeeded, - }) => plan.map_elements(f)?.update_data(|plan| { - LogicalPlan::Explain(Explain { + .update_data(LogicalPlan::distinct), + LogicalPlan::Explain( + Explain { + verbose, + plan, + stringified_plans, + schema, + logical_optimization_succeeded, + }, + _, + ) => plan.map_elements(f)?.update_data(|plan| { + LogicalPlan::explain(Explain { verbose, plan, stringified_plans, @@ -216,25 +347,31 @@ impl TreeNode for LogicalPlan { logical_optimization_succeeded, }) }), - LogicalPlan::Analyze(Analyze { - verbose, - input, - schema, - }) => input.map_elements(f)?.update_data(|input| { - LogicalPlan::Analyze(Analyze { + LogicalPlan::Analyze( + Analyze { + verbose, + input, + schema, + }, + _, + ) => input.map_elements(f)?.update_data(|input| { + LogicalPlan::analyze(Analyze { verbose, input, schema, }) }), - LogicalPlan::Dml(DmlStatement { - table_name, - table_schema, - op, - input, - output_schema, - }) => input.map_elements(f)?.update_data(|input| { - LogicalPlan::Dml(DmlStatement { + LogicalPlan::Dml( + DmlStatement { + table_name, + table_schema, + op, + input, + output_schema, + }, + _, + ) => input.map_elements(f)?.update_data(|input| { + LogicalPlan::dml(DmlStatement { table_name, table_schema, op, @@ -242,14 +379,17 @@ impl TreeNode for LogicalPlan { output_schema, }) }), - LogicalPlan::Copy(CopyTo { - input, - output_url, - partition_by, - file_type, - options, - }) => input.map_elements(f)?.update_data(|input| { - LogicalPlan::Copy(CopyTo { + LogicalPlan::Copy( + CopyTo { + input, + output_url, + partition_by, + file_type, + options, + }, + _, + ) => input.map_elements(f)?.update_data(|input| { + LogicalPlan::copy(CopyTo { input, output_url, partition_by, @@ -257,7 +397,7 @@ impl TreeNode for LogicalPlan { options, }) }), - LogicalPlan::Ddl(ddl) => { + LogicalPlan::Ddl(ddl, _) => { match ddl { DdlStatement::CreateMemoryTable(CreateMemoryTable { name, @@ -304,18 +444,21 @@ impl TreeNode for LogicalPlan { | DdlStatement::CreateFunction(_) | DdlStatement::DropFunction(_) => Transformed::no(ddl), } - .update_data(LogicalPlan::Ddl) + .update_data(LogicalPlan::ddl) } - LogicalPlan::Unnest(Unnest { - input, - exec_columns: input_columns, - list_type_columns, - struct_type_columns, - dependency_indices, - schema, - options, - }) => input.map_elements(f)?.update_data(|input| { - LogicalPlan::Unnest(Unnest { + LogicalPlan::Unnest( + Unnest { + input, + exec_columns: input_columns, + list_type_columns, + struct_type_columns, + dependency_indices, + schema, + options, + }, + _, + ) => input.map_elements(f)?.update_data(|input| { + LogicalPlan::unnest(Unnest { input, exec_columns: input_columns, dependency_indices, @@ -325,14 +468,17 @@ impl TreeNode for LogicalPlan { options, }) }), - LogicalPlan::RecursiveQuery(RecursiveQuery { - name, - static_term, - recursive_term, - is_distinct, - }) => (static_term, recursive_term).map_elements(f)?.update_data( + LogicalPlan::RecursiveQuery( + RecursiveQuery { + name, + static_term, + recursive_term, + is_distinct, + }, + _, + ) => (static_term, recursive_term).map_elements(f)?.update_data( |(static_term, recursive_term)| { - LogicalPlan::RecursiveQuery(RecursiveQuery { + LogicalPlan::recursive_query(RecursiveQuery { name, static_term, recursive_term, @@ -340,19 +486,19 @@ impl TreeNode for LogicalPlan { }) }, ), - LogicalPlan::Statement(stmt) => match stmt { + LogicalPlan::Statement(stmt, _) => match stmt { Statement::Prepare(p) => p .input .map_elements(f)? .update_data(|input| Statement::Prepare(Prepare { input, ..p })), _ => Transformed::no(stmt), } - .update_data(LogicalPlan::Statement), + .update_data(LogicalPlan::statement), // plans without inputs LogicalPlan::TableScan { .. } | LogicalPlan::EmptyRelation { .. } | LogicalPlan::Values { .. } - | LogicalPlan::DescribeTable(_) => Transformed::no(self), + | LogicalPlan::DescribeTable(_, _) => Transformed::no(self), }) } } @@ -403,78 +549,87 @@ impl LogicalPlan { mut f: F, ) -> Result { match self { - LogicalPlan::Projection(Projection { expr, .. }) => expr.apply_elements(f), - LogicalPlan::Values(Values { values, .. }) => values.apply_elements(f), - LogicalPlan::Filter(Filter { predicate, .. }) => f(predicate), - LogicalPlan::Repartition(Repartition { - partitioning_scheme, - .. - }) => match partitioning_scheme { + LogicalPlan::Projection(Projection { expr, .. }, _) => expr.apply_elements(f), + LogicalPlan::Values(Values { values, .. }, _) => values.apply_elements(f), + LogicalPlan::Filter(Filter { predicate, .. }, _) => f(predicate), + LogicalPlan::Repartition( + Repartition { + partitioning_scheme, + .. + }, + _, + ) => match partitioning_scheme { Partitioning::Hash(expr, _) | Partitioning::DistributeBy(expr) => { expr.apply_elements(f) } Partitioning::RoundRobinBatch(_) => Ok(TreeNodeRecursion::Continue), }, - LogicalPlan::Window(Window { window_expr, .. }) => { + LogicalPlan::Window(Window { window_expr, .. }, _) => { window_expr.apply_elements(f) } - LogicalPlan::Aggregate(Aggregate { - group_expr, - aggr_expr, - .. - }) => (group_expr, aggr_expr).apply_ref_elements(f), + LogicalPlan::Aggregate( + Aggregate { + group_expr, + aggr_expr, + .. + }, + _, + ) => (group_expr, aggr_expr).apply_ref_elements(f), // There are two part of expression for join, equijoin(on) and non-equijoin(filter). // 1. the first part is `on.len()` equijoin expressions, and the struct of each expr is `left-on = right-on`. // 2. the second part is non-equijoin(filter). - LogicalPlan::Join(Join { on, filter, .. }) => { + LogicalPlan::Join(Join { on, filter, .. }, _) => { (on, filter).apply_ref_elements(f) } - LogicalPlan::Sort(Sort { expr, .. }) => expr.apply_elements(f), - LogicalPlan::Extension(extension) => { + LogicalPlan::Sort(Sort { expr, .. }, _) => expr.apply_elements(f), + LogicalPlan::Extension(extension, _) => { // would be nice to avoid this copy -- maybe can // update extension to just observer Exprs extension.node.expressions().apply_elements(f) } - LogicalPlan::TableScan(TableScan { filters, .. }) => { + LogicalPlan::TableScan(TableScan { filters, .. }, _) => { filters.apply_elements(f) } - LogicalPlan::Unnest(unnest) => { + LogicalPlan::Unnest(unnest, _) => { let columns = unnest.exec_columns.clone(); let exprs = columns .iter() - .map(|c| Expr::Column(c.clone())) + .map(|c| Expr::column(c.clone())) .collect::>(); exprs.apply_elements(f) } - LogicalPlan::Distinct(Distinct::On(DistinctOn { - on_expr, - select_expr, - sort_expr, - .. - })) => (on_expr, select_expr, sort_expr).apply_ref_elements(f), - LogicalPlan::Limit(Limit { skip, fetch, .. }) => { + LogicalPlan::Distinct( + Distinct::On(DistinctOn { + on_expr, + select_expr, + sort_expr, + .. + }), + _, + ) => (on_expr, select_expr, sort_expr).apply_ref_elements(f), + LogicalPlan::Limit(Limit { skip, fetch, .. }, _) => { (skip, fetch).apply_ref_elements(f) } - LogicalPlan::Statement(stmt) => match stmt { + LogicalPlan::Statement(stmt, _) => match stmt { Statement::Execute(Execute { parameters, .. }) => { parameters.apply_elements(f) } _ => Ok(TreeNodeRecursion::Continue), }, // plans without expressions - LogicalPlan::EmptyRelation(_) - | LogicalPlan::RecursiveQuery(_) - | LogicalPlan::Subquery(_) - | LogicalPlan::SubqueryAlias(_) - | LogicalPlan::Analyze(_) - | LogicalPlan::Explain(_) - | LogicalPlan::Union(_) - | LogicalPlan::Distinct(Distinct::All(_)) - | LogicalPlan::Dml(_) - | LogicalPlan::Ddl(_) - | LogicalPlan::Copy(_) - | LogicalPlan::DescribeTable(_) => Ok(TreeNodeRecursion::Continue), + LogicalPlan::EmptyRelation(_, _) + | LogicalPlan::RecursiveQuery(_, _) + | LogicalPlan::Subquery(_, _) + | LogicalPlan::SubqueryAlias(_, _) + | LogicalPlan::Analyze(_, _) + | LogicalPlan::Explain(_, _) + | LogicalPlan::Union(_, _) + | LogicalPlan::Distinct(Distinct::All(_), _) + | LogicalPlan::Dml(_, _) + | LogicalPlan::Ddl(_, _) + | LogicalPlan::Copy(_, _) + | LogicalPlan::DescribeTable(_, _) => Ok(TreeNodeRecursion::Continue), } } @@ -490,35 +645,44 @@ impl LogicalPlan { mut f: F, ) -> Result> { Ok(match self { - LogicalPlan::Projection(Projection { - expr, - input, - schema, - }) => expr.map_elements(f)?.update_data(|expr| { - LogicalPlan::Projection(Projection { + LogicalPlan::Projection( + Projection { + expr, + input, + schema, + }, + _, + ) => expr.map_elements(f)?.update_data(|expr| { + LogicalPlan::projection(Projection { expr, input, schema, }) }), - LogicalPlan::Values(Values { schema, values }) => values + LogicalPlan::Values(Values { schema, values }, _) => values .map_elements(f)? - .update_data(|values| LogicalPlan::Values(Values { schema, values })), - LogicalPlan::Filter(Filter { - predicate, - input, - having, - }) => f(predicate)?.update_data(|predicate| { - LogicalPlan::Filter(Filter { + .update_data(|values| LogicalPlan::values(Values { schema, values })), + LogicalPlan::Filter( + Filter { + predicate, + input, + having, + }, + _, + ) => f(predicate)?.update_data(|predicate| { + LogicalPlan::filter(Filter { predicate, input, having, }) }), - LogicalPlan::Repartition(Repartition { - input, - partitioning_scheme, - }) => match partitioning_scheme { + LogicalPlan::Repartition( + Repartition { + input, + partitioning_scheme, + }, + _, + ) => match partitioning_scheme { Partitioning::Hash(expr, usize) => expr .map_elements(f)? .update_data(|expr| Partitioning::Hash(expr, usize)), @@ -528,30 +692,36 @@ impl LogicalPlan { Partitioning::RoundRobinBatch(_) => Transformed::no(partitioning_scheme), } .update_data(|partitioning_scheme| { - LogicalPlan::Repartition(Repartition { + LogicalPlan::repartition(Repartition { input, partitioning_scheme, }) }), - LogicalPlan::Window(Window { - input, - window_expr, - schema, - }) => window_expr.map_elements(f)?.update_data(|window_expr| { - LogicalPlan::Window(Window { + LogicalPlan::Window( + Window { + input, + window_expr, + schema, + }, + _, + ) => window_expr.map_elements(f)?.update_data(|window_expr| { + LogicalPlan::window(Window { input, window_expr, schema, }) }), - LogicalPlan::Aggregate(Aggregate { - input, - group_expr, - aggr_expr, - schema, - }) => (group_expr, aggr_expr).map_elements(f)?.update_data( + LogicalPlan::Aggregate( + Aggregate { + input, + group_expr, + aggr_expr, + schema, + }, + _, + ) => (group_expr, aggr_expr).map_elements(f)?.update_data( |(group_expr, aggr_expr)| { - LogicalPlan::Aggregate(Aggregate { + LogicalPlan::aggregate(Aggregate { input, group_expr, aggr_expr, @@ -563,17 +733,20 @@ impl LogicalPlan { // There are two part of expression for join, equijoin(on) and non-equijoin(filter). // 1. the first part is `on.len()` equijoin expressions, and the struct of each expr is `left-on = right-on`. // 2. the second part is non-equijoin(filter). - LogicalPlan::Join(Join { - left, - right, - on, - filter, - join_type, - join_constraint, - schema, - null_equals_null, - }) => (on, filter).map_elements(f)?.update_data(|(on, filter)| { - LogicalPlan::Join(Join { + LogicalPlan::Join( + Join { + left, + right, + on, + filter, + join_type, + join_constraint, + schema, + null_equals_null, + }, + _, + ) => (on, filter).map_elements(f)?.update_data(|(on, filter)| { + LogicalPlan::join(Join { left, right, on, @@ -584,14 +757,14 @@ impl LogicalPlan { null_equals_null, }) }), - LogicalPlan::Sort(Sort { expr, input, fetch }) => expr + LogicalPlan::Sort(Sort { expr, input, fetch }, _) => expr .map_elements(f)? - .update_data(|expr| LogicalPlan::Sort(Sort { expr, input, fetch })), - LogicalPlan::Extension(Extension { node }) => { + .update_data(|expr| LogicalPlan::sort(Sort { expr, input, fetch })), + LogicalPlan::Extension(Extension { node }, _) => { // would be nice to avoid this copy -- maybe can // update extension to just observer Exprs let exprs = node.expressions().map_elements(f)?; - let plan = LogicalPlan::Extension(Extension { + let plan = LogicalPlan::extension(Extension { node: UserDefinedLogicalNode::with_exprs_and_inputs( node.as_ref(), exprs.data, @@ -600,15 +773,18 @@ impl LogicalPlan { }); Transformed::new(plan, exprs.transformed, exprs.tnr) } - LogicalPlan::TableScan(TableScan { - table_name, - source, - projection, - projected_schema, - filters, - fetch, - }) => filters.map_elements(f)?.update_data(|filters| { - LogicalPlan::TableScan(TableScan { + LogicalPlan::TableScan( + TableScan { + table_name, + source, + projection, + projected_schema, + filters, + fetch, + }, + _, + ) => filters.map_elements(f)?.update_data(|filters| { + LogicalPlan::table_scan(TableScan { table_name, source, projection, @@ -617,16 +793,19 @@ impl LogicalPlan { fetch, }) }), - LogicalPlan::Distinct(Distinct::On(DistinctOn { - on_expr, - select_expr, - sort_expr, - input, - schema, - })) => (on_expr, select_expr, sort_expr) + LogicalPlan::Distinct( + Distinct::On(DistinctOn { + on_expr, + select_expr, + sort_expr, + input, + schema, + }), + _, + ) => (on_expr, select_expr, sort_expr) .map_elements(f)? .update_data(|(on_expr, select_expr, sort_expr)| { - LogicalPlan::Distinct(Distinct::On(DistinctOn { + LogicalPlan::distinct(Distinct::On(DistinctOn { on_expr, select_expr, sort_expr, @@ -634,12 +813,12 @@ impl LogicalPlan { schema, })) }), - LogicalPlan::Limit(Limit { skip, fetch, input }) => { + LogicalPlan::Limit(Limit { skip, fetch, input }, _) => { (skip, fetch).map_elements(f)?.update_data(|(skip, fetch)| { - LogicalPlan::Limit(Limit { skip, fetch, input }) + LogicalPlan::limit(Limit { skip, fetch, input }) }) } - LogicalPlan::Statement(stmt) => match stmt { + LogicalPlan::Statement(stmt, _) => match stmt { Statement::Execute(e) => { e.parameters.map_elements(f)?.update_data(|parameters| { Statement::Execute(Execute { parameters, ..e }) @@ -647,21 +826,21 @@ impl LogicalPlan { } _ => Transformed::no(stmt), } - .update_data(LogicalPlan::Statement), + .update_data(LogicalPlan::statement), // plans without expressions - LogicalPlan::EmptyRelation(_) - | LogicalPlan::Unnest(_) - | LogicalPlan::RecursiveQuery(_) - | LogicalPlan::Subquery(_) - | LogicalPlan::SubqueryAlias(_) - | LogicalPlan::Analyze(_) - | LogicalPlan::Explain(_) - | LogicalPlan::Union(_) - | LogicalPlan::Distinct(Distinct::All(_)) - | LogicalPlan::Dml(_) - | LogicalPlan::Ddl(_) - | LogicalPlan::Copy(_) - | LogicalPlan::DescribeTable(_) => Transformed::no(self), + LogicalPlan::EmptyRelation(_, _) + | LogicalPlan::Unnest(_, _) + | LogicalPlan::RecursiveQuery(_, _) + | LogicalPlan::Subquery(_, _) + | LogicalPlan::SubqueryAlias(_, _) + | LogicalPlan::Analyze(_, _) + | LogicalPlan::Explain(_, _) + | LogicalPlan::Union(_, _) + | LogicalPlan::Distinct(Distinct::All(_), _) + | LogicalPlan::Ddl(_, _) + | LogicalPlan::Dml(_, _) + | LogicalPlan::Copy(_, _) + | LogicalPlan::DescribeTable(_, _) => Transformed::no(self), }) } @@ -821,13 +1000,13 @@ impl LogicalPlan { ) -> Result { self.apply_expressions(|expr| { expr.apply(|expr| match expr { - Expr::Exists(Exists { subquery, .. }) - | Expr::InSubquery(InSubquery { subquery, .. }) - | Expr::ScalarSubquery(subquery) => { + Expr::Exists(Exists { subquery, .. }, _) + | Expr::InSubquery(InSubquery { subquery, .. }, _) + | Expr::ScalarSubquery(subquery, _) => { // use a synthetic plan so the collector sees a // LogicalPlan::Subquery (even though it is // actually a Subquery alias) - f(&LogicalPlan::Subquery(subquery.clone())) + f(&LogicalPlan::subquery(subquery.clone())) } _ => Ok(TreeNodeRecursion::Continue), }) @@ -844,30 +1023,35 @@ impl LogicalPlan { ) -> Result> { self.map_expressions(|expr| { expr.transform_down(|expr| match expr { - Expr::Exists(Exists { subquery, negated }) => { - f(LogicalPlan::Subquery(subquery))?.map_data(|s| match s { - LogicalPlan::Subquery(subquery) => { - Ok(Expr::Exists(Exists { subquery, negated })) + Expr::Exists(Exists { subquery, negated }, _) => { + f(LogicalPlan::subquery(subquery))?.map_data(|s| match s { + LogicalPlan::Subquery(subquery, _) => { + Ok(Expr::exists(Exists { subquery, negated })) } _ => internal_err!("Transformation should return Subquery"), }) } - Expr::InSubquery(InSubquery { - expr, - subquery, - negated, - }) => f(LogicalPlan::Subquery(subquery))?.map_data(|s| match s { - LogicalPlan::Subquery(subquery) => Ok(Expr::InSubquery(InSubquery { + Expr::InSubquery( + InSubquery { expr, subquery, negated, - })), + }, + _, + ) => f(LogicalPlan::subquery(subquery))?.map_data(|s| match s { + LogicalPlan::Subquery(subquery, _) => { + Ok(Expr::in_subquery(InSubquery { + expr, + subquery, + negated, + })) + } _ => internal_err!("Transformation should return Subquery"), }), - Expr::ScalarSubquery(subquery) => f(LogicalPlan::Subquery(subquery))? + Expr::ScalarSubquery(subquery, _) => f(LogicalPlan::subquery(subquery))? .map_data(|s| match s { - LogicalPlan::Subquery(subquery) => { - Ok(Expr::ScalarSubquery(subquery)) + LogicalPlan::Subquery(subquery, _) => { + Ok(Expr::scalar_subquery(subquery)) } _ => internal_err!("Transformation should return Subquery"), }), @@ -875,4 +1059,34 @@ impl LogicalPlan { }) }) } + + pub fn stats(&self) -> LogicalPlanStats { + match self { + LogicalPlan::Projection(_, stats) => *stats, + LogicalPlan::Filter(_, stats) => *stats, + LogicalPlan::Window(_, stats) => *stats, + LogicalPlan::Aggregate(_, stats) => *stats, + LogicalPlan::Sort(_, stats) => *stats, + LogicalPlan::Join(_, stats) => *stats, + LogicalPlan::Repartition(_, stats) => *stats, + LogicalPlan::Union(_, stats) => *stats, + LogicalPlan::TableScan(_, stats) => *stats, + LogicalPlan::EmptyRelation(_, stats) => *stats, + LogicalPlan::Subquery(_, stats) => *stats, + LogicalPlan::SubqueryAlias(_, stats) => *stats, + LogicalPlan::Limit(_, stats) => *stats, + LogicalPlan::Statement(_, stats) => *stats, + LogicalPlan::Values(_, stats) => *stats, + LogicalPlan::Explain(_, stats) => *stats, + LogicalPlan::Analyze(_, stats) => *stats, + LogicalPlan::Extension(_, stats) => *stats, + LogicalPlan::Distinct(_, stats) => *stats, + LogicalPlan::Dml(_, stats) => *stats, + LogicalPlan::Ddl(_, stats) => *stats, + LogicalPlan::Copy(_, stats) => *stats, + LogicalPlan::DescribeTable(_, stats) => *stats, + LogicalPlan::Unnest(_, stats) => *stats, + LogicalPlan::RecursiveQuery(_, stats) => *stats, + } + } } diff --git a/datafusion/expr/src/operation.rs b/datafusion/expr/src/operation.rs index 6b79a8248b293..fd532bde9d690 100644 --- a/datafusion/expr/src/operation.rs +++ b/datafusion/expr/src/operation.rs @@ -117,7 +117,7 @@ impl ops::Neg for Expr { type Output = Self; fn neg(self) -> Self::Output { - Expr::Negative(Box::new(self)) + Expr::negative(Box::new(self)) } } @@ -127,33 +127,33 @@ impl Not for Expr { fn not(self) -> Self::Output { match self { - Expr::Like(Like { - negated, - expr, - pattern, - escape_char, - case_insensitive, - }) => Expr::Like(Like::new( - !negated, - expr, - pattern, - escape_char, - case_insensitive, - )), - Expr::SimilarTo(Like { - negated, - expr, - pattern, - escape_char, - case_insensitive, - }) => Expr::SimilarTo(Like::new( - !negated, - expr, - pattern, - escape_char, - case_insensitive, - )), - _ => Expr::Not(Box::new(self)), + Expr::Like( + Like { + negated, + expr, + pattern, + escape_char, + case_insensitive, + }, + stats, + ) => Expr::Like( + Like::new(!negated, expr, pattern, escape_char, case_insensitive), + stats, + ), + Expr::SimilarTo( + Like { + negated, + expr, + pattern, + escape_char, + case_insensitive, + }, + stats, + ) => Expr::SimilarTo( + Like::new(!negated, expr, pattern, escape_char, case_insensitive), + stats, + ), + _ => Expr::_not(Box::new(self)), } } } diff --git a/datafusion/expr/src/test/function_stub.rs b/datafusion/expr/src/test/function_stub.rs index 262aa99e50075..f8cdce45fbfff 100644 --- a/datafusion/expr/src/test/function_stub.rs +++ b/datafusion/expr/src/test/function_stub.rs @@ -60,7 +60,7 @@ macro_rules! create_func { create_func!(Sum, sum_udaf); pub fn sum(expr: Expr) -> Expr { - Expr::AggregateFunction(AggregateFunction::new_udf( + Expr::aggregate_function(AggregateFunction::new_udf( sum_udaf(), vec![expr], false, @@ -73,7 +73,7 @@ pub fn sum(expr: Expr) -> Expr { create_func!(Count, count_udaf); pub fn count(expr: Expr) -> Expr { - Expr::AggregateFunction(AggregateFunction::new_udf( + Expr::aggregate_function(AggregateFunction::new_udf( count_udaf(), vec![expr], false, @@ -86,7 +86,7 @@ pub fn count(expr: Expr) -> Expr { create_func!(Avg, avg_udaf); pub fn avg(expr: Expr) -> Expr { - Expr::AggregateFunction(AggregateFunction::new_udf( + Expr::aggregate_function(AggregateFunction::new_udf( avg_udaf(), vec![expr], false, @@ -284,7 +284,7 @@ impl AggregateUDFImpl for Count { create_func!(Min, min_udaf); pub fn min(expr: Expr) -> Expr { - Expr::AggregateFunction(AggregateFunction::new_udf( + Expr::aggregate_function(AggregateFunction::new_udf( min_udaf(), vec![expr], false, @@ -369,7 +369,7 @@ impl AggregateUDFImpl for Min { create_func!(Max, max_udaf); pub fn max(expr: Expr) -> Expr { - Expr::AggregateFunction(AggregateFunction::new_udf( + Expr::aggregate_function(AggregateFunction::new_udf( max_udaf(), vec![expr], false, diff --git a/datafusion/expr/src/tree_node.rs b/datafusion/expr/src/tree_node.rs index eacace5ed0461..882565ad2f02e 100644 --- a/datafusion/expr/src/tree_node.rs +++ b/datafusion/expr/src/tree_node.rs @@ -23,6 +23,7 @@ use crate::expr::{ }; use crate::{Expr, ExprFunctionExt}; +use crate::logical_plan::tree_node::LogicalPlanStats; use datafusion_common::tree_node::{ Transformed, TreeNode, TreeNodeContainer, TreeNodeRecursion, TreeNodeRefContainer, }; @@ -43,61 +44,61 @@ impl TreeNode for Expr { f: F, ) -> Result { match self { - Expr::Alias(Alias { expr, .. }) - | Expr::Unnest(Unnest { expr }) - | Expr::Not(expr) - | Expr::IsNotNull(expr) - | Expr::IsTrue(expr) - | Expr::IsFalse(expr) - | Expr::IsUnknown(expr) - | Expr::IsNotTrue(expr) - | Expr::IsNotFalse(expr) - | Expr::IsNotUnknown(expr) - | Expr::IsNull(expr) - | Expr::Negative(expr) - | Expr::Cast(Cast { expr, .. }) - | Expr::TryCast(TryCast { expr, .. }) - | Expr::InSubquery(InSubquery { expr, .. }) => expr.apply_elements(f), - Expr::GroupingSet(GroupingSet::Rollup(exprs)) - | Expr::GroupingSet(GroupingSet::Cube(exprs)) => exprs.apply_elements(f), - Expr::ScalarFunction(ScalarFunction { args, .. }) => { + Expr::Alias(Alias { expr, .. }, _) + | Expr::Unnest(Unnest { expr }, _) + | Expr::Not(expr, _) + | Expr::IsNotNull(expr, _) + | Expr::IsTrue(expr, _) + | Expr::IsFalse(expr, _) + | Expr::IsUnknown(expr, _) + | Expr::IsNotTrue(expr, _) + | Expr::IsNotFalse(expr, _) + | Expr::IsNotUnknown(expr, _) + | Expr::IsNull(expr, _) + | Expr::Negative(expr, _) + | Expr::Cast(Cast { expr, .. }, _) + | Expr::TryCast(TryCast { expr, .. }, _) + | Expr::InSubquery(InSubquery { expr, .. }, _) => expr.apply_elements(f), + Expr::GroupingSet(GroupingSet::Rollup(exprs), _) + | Expr::GroupingSet(GroupingSet::Cube(exprs), _) => exprs.apply_elements(f), + Expr::ScalarFunction(ScalarFunction { args, .. }, _) => { args.apply_elements(f) } - Expr::GroupingSet(GroupingSet::GroupingSets(lists_of_exprs)) => { + Expr::GroupingSet(GroupingSet::GroupingSets(lists_of_exprs), _) => { lists_of_exprs.apply_elements(f) } - Expr::Column(_) + Expr::Column(_, _) // Treat OuterReferenceColumn as a leaf expression - | Expr::OuterReferenceColumn(_, _) - | Expr::ScalarVariable(_, _) - | Expr::Literal(_) + | Expr::OuterReferenceColumn(_, _, _) + | Expr::ScalarVariable(_, _, _) + | Expr::Literal(_, _) | Expr::Exists { .. } - | Expr::ScalarSubquery(_) + | Expr::ScalarSubquery(_, _) | Expr::Wildcard { .. } - | Expr::Placeholder(_) => Ok(TreeNodeRecursion::Continue), - Expr::BinaryExpr(BinaryExpr { left, right, .. }) => { + | Expr::Placeholder(_, _) => Ok(TreeNodeRecursion::Continue), + Expr::BinaryExpr(BinaryExpr { left, right, .. }, _) => { (left, right).apply_ref_elements(f) } - Expr::Like(Like { expr, pattern, .. }) - | Expr::SimilarTo(Like { expr, pattern, .. }) => { + Expr::Like(Like { expr, pattern, .. }, _) + | Expr::SimilarTo(Like { expr, pattern, .. }, _) => { (expr, pattern).apply_ref_elements(f) } Expr::Between(Between { expr, low, high, .. - }) => (expr, low, high).apply_ref_elements(f), - Expr::Case(Case { expr, when_then_expr, else_expr }) => + }, _) => (expr, low, high).apply_ref_elements(f), + Expr::Case(Case { expr, when_then_expr, else_expr }, _) => (expr, when_then_expr, else_expr).apply_ref_elements(f), - Expr::AggregateFunction(AggregateFunction { args, filter, order_by, .. }) => + Expr::AggregateFunction(AggregateFunction { args, filter, order_by, .. }, _) => (args, filter, order_by).apply_ref_elements(f), Expr::WindowFunction(WindowFunction { args, partition_by, order_by, .. - }) => { + }, _) => { (args, partition_by, order_by).apply_ref_elements(f) } - Expr::InList(InList { expr, list, .. }) => { + Expr::InList(InList { expr, list, .. }, _) => { (expr, list).apply_ref_elements(f) } } @@ -112,45 +113,54 @@ impl TreeNode for Expr { mut f: F, ) -> Result> { Ok(match self { - Expr::Column(_) + Expr::Column(_, _) | Expr::Wildcard { .. } - | Expr::Placeholder(Placeholder { .. }) - | Expr::OuterReferenceColumn(_, _) + | Expr::Placeholder(Placeholder { .. }, _) + | Expr::OuterReferenceColumn(_, _, _) | Expr::Exists { .. } - | Expr::ScalarSubquery(_) - | Expr::ScalarVariable(_, _) - | Expr::Literal(_) => Transformed::no(self), - Expr::Unnest(Unnest { expr, .. }) => expr + | Expr::ScalarSubquery(_, _) + | Expr::ScalarVariable(_, _, _) + | Expr::Literal(_, _) => Transformed::no(self), + Expr::Unnest(Unnest { expr, .. }, _) => expr .map_elements(f)? - .update_data(|expr| Expr::Unnest(Unnest { expr })), - Expr::Alias(Alias { - expr, - relation, - name, - }) => f(*expr)?.update_data(|e| e.alias_qualified(relation, name)), - Expr::InSubquery(InSubquery { - expr, - subquery, - negated, - }) => expr.map_elements(f)?.update_data(|be| { - Expr::InSubquery(InSubquery::new(be, subquery, negated)) + .update_data(|expr| Expr::unnest(Unnest { expr })), + Expr::Alias( + Alias { + expr, + relation, + name, + }, + _, + ) => f(*expr)?.update_data(|e| e.alias_qualified(relation, name)), + Expr::InSubquery( + InSubquery { + expr, + subquery, + negated, + }, + _, + ) => expr.map_elements(f)?.update_data(|be| { + Expr::in_subquery(InSubquery::new(be, subquery, negated)) }), - Expr::BinaryExpr(BinaryExpr { left, op, right }) => (left, right) + Expr::BinaryExpr(BinaryExpr { left, op, right }, _) => (left, right) .map_elements(f)? .update_data(|(new_left, new_right)| { - Expr::BinaryExpr(BinaryExpr::new(new_left, op, new_right)) + Expr::binary_expr(BinaryExpr::new(new_left, op, new_right)) }), - Expr::Like(Like { - negated, - expr, - pattern, - escape_char, - case_insensitive, - }) => { + Expr::Like( + Like { + negated, + expr, + pattern, + escape_char, + case_insensitive, + }, + _, + ) => { (expr, pattern) .map_elements(f)? .update_data(|(new_expr, new_pattern)| { - Expr::Like(Like::new( + Expr::_like(Like::new( negated, new_expr, new_pattern, @@ -159,17 +169,20 @@ impl TreeNode for Expr { )) }) } - Expr::SimilarTo(Like { - negated, - expr, - pattern, - escape_char, - case_insensitive, - }) => { + Expr::SimilarTo( + Like { + negated, + expr, + pattern, + escape_char, + case_insensitive, + }, + _, + ) => { (expr, pattern) .map_elements(f)? .update_data(|(new_expr, new_pattern)| { - Expr::SimilarTo(Like::new( + Expr::similar_to(Like::new( negated, new_expr, new_pattern, @@ -178,60 +191,77 @@ impl TreeNode for Expr { )) }) } - Expr::Not(expr) => expr.map_elements(f)?.update_data(Expr::Not), - Expr::IsNotNull(expr) => expr.map_elements(f)?.update_data(Expr::IsNotNull), - Expr::IsNull(expr) => expr.map_elements(f)?.update_data(Expr::IsNull), - Expr::IsTrue(expr) => expr.map_elements(f)?.update_data(Expr::IsTrue), - Expr::IsFalse(expr) => expr.map_elements(f)?.update_data(Expr::IsFalse), - Expr::IsUnknown(expr) => expr.map_elements(f)?.update_data(Expr::IsUnknown), - Expr::IsNotTrue(expr) => expr.map_elements(f)?.update_data(Expr::IsNotTrue), - Expr::IsNotFalse(expr) => expr.map_elements(f)?.update_data(Expr::IsNotFalse), - Expr::IsNotUnknown(expr) => { - expr.map_elements(f)?.update_data(Expr::IsNotUnknown) + Expr::Not(expr, _) => expr.map_elements(f)?.update_data(Expr::_not), + Expr::IsNotNull(expr, _) => { + expr.map_elements(f)?.update_data(Expr::_is_not_null) } - Expr::Negative(expr) => expr.map_elements(f)?.update_data(Expr::Negative), - Expr::Between(Between { - expr, - negated, - low, - high, - }) => (expr, low, high).map_elements(f)?.update_data( + Expr::IsNull(expr, _) => expr.map_elements(f)?.update_data(Expr::_is_null), + Expr::IsTrue(expr, _) => expr.map_elements(f)?.update_data(Expr::_is_true), + Expr::IsFalse(expr, _) => expr.map_elements(f)?.update_data(Expr::_is_false), + Expr::IsUnknown(expr, _) => { + expr.map_elements(f)?.update_data(Expr::_is_unknown) + } + Expr::IsNotTrue(expr, _) => { + expr.map_elements(f)?.update_data(Expr::_is_not_true) + } + Expr::IsNotFalse(expr, _) => { + expr.map_elements(f)?.update_data(Expr::_is_not_false) + } + Expr::IsNotUnknown(expr, _) => { + expr.map_elements(f)?.update_data(Expr::_is_not_unknown) + } + Expr::Negative(expr, _) => expr.map_elements(f)?.update_data(Expr::negative), + Expr::Between( + Between { + expr, + negated, + low, + high, + }, + _, + ) => (expr, low, high).map_elements(f)?.update_data( |(new_expr, new_low, new_high)| { - Expr::Between(Between::new(new_expr, negated, new_low, new_high)) + Expr::_between(Between::new(new_expr, negated, new_low, new_high)) }, ), - Expr::Case(Case { - expr, - when_then_expr, - else_expr, - }) => (expr, when_then_expr, else_expr) + Expr::Case( + Case { + expr, + when_then_expr, + else_expr, + }, + _, + ) => (expr, when_then_expr, else_expr) .map_elements(f)? .update_data(|(new_expr, new_when_then_expr, new_else_expr)| { - Expr::Case(Case::new(new_expr, new_when_then_expr, new_else_expr)) + Expr::case(Case::new(new_expr, new_when_then_expr, new_else_expr)) }), - Expr::Cast(Cast { expr, data_type }) => expr + Expr::Cast(Cast { expr, data_type }, _) => expr .map_elements(f)? - .update_data(|be| Expr::Cast(Cast::new(be, data_type))), - Expr::TryCast(TryCast { expr, data_type }) => expr + .update_data(|be| Expr::cast(Cast::new(be, data_type))), + Expr::TryCast(TryCast { expr, data_type }, _) => expr .map_elements(f)? - .update_data(|be| Expr::TryCast(TryCast::new(be, data_type))), - Expr::ScalarFunction(ScalarFunction { func, args }) => { + .update_data(|be| Expr::try_cast(TryCast::new(be, data_type))), + Expr::ScalarFunction(ScalarFunction { func, args }, _) => { args.map_elements(f)?.map_data(|new_args| { - Ok(Expr::ScalarFunction(ScalarFunction::new_udf( + Ok(Expr::scalar_function(ScalarFunction::new_udf( func, new_args, ))) })? } - Expr::WindowFunction(WindowFunction { - args, - fun, - partition_by, - order_by, - window_frame, - null_treatment, - }) => (args, partition_by, order_by).map_elements(f)?.update_data( + Expr::WindowFunction( + WindowFunction { + args, + fun, + partition_by, + order_by, + window_frame, + null_treatment, + }, + _, + ) => (args, partition_by, order_by).map_elements(f)?.update_data( |(new_args, new_partition_by, new_order_by)| { - Expr::WindowFunction(WindowFunction::new(fun, new_args)) + Expr::window_function(WindowFunction::new(fun, new_args)) .partition_by(new_partition_by) .order_by(new_order_by) .window_frame(window_frame) @@ -240,16 +270,19 @@ impl TreeNode for Expr { .unwrap() }, ), - Expr::AggregateFunction(AggregateFunction { - args, - func, - distinct, - filter, - order_by, - null_treatment, - }) => (args, filter, order_by).map_elements(f)?.map_data( + Expr::AggregateFunction( + AggregateFunction { + args, + func, + distinct, + filter, + order_by, + null_treatment, + }, + _, + ) => (args, filter, order_by).map_elements(f)?.map_data( |(new_args, new_filter, new_order_by)| { - Ok(Expr::AggregateFunction(AggregateFunction::new_udf( + Ok(Expr::aggregate_function(AggregateFunction::new_udf( func, new_args, distinct, @@ -259,28 +292,69 @@ impl TreeNode for Expr { ))) }, )?, - Expr::GroupingSet(grouping_set) => match grouping_set { - GroupingSet::Rollup(exprs) => exprs - .map_elements(f)? - .update_data(|ve| Expr::GroupingSet(GroupingSet::Rollup(ve))), - GroupingSet::Cube(exprs) => exprs - .map_elements(f)? - .update_data(|ve| Expr::GroupingSet(GroupingSet::Cube(ve))), + Expr::GroupingSet(grouping_set, _) => match grouping_set { + GroupingSet::Rollup(exprs) => { + exprs.map_elements(f)?.update_data(GroupingSet::Rollup) + } + GroupingSet::Cube(exprs) => { + exprs.map_elements(f)?.update_data(GroupingSet::Cube) + } GroupingSet::GroupingSets(lists_of_exprs) => lists_of_exprs .map_elements(f)? - .update_data(|new_lists_of_exprs| { - Expr::GroupingSet(GroupingSet::GroupingSets(new_lists_of_exprs)) - }), - }, - Expr::InList(InList { - expr, - list, - negated, - }) => (expr, list) + .update_data(GroupingSet::GroupingSets), + } + .update_data(Expr::grouping_set), + Expr::InList( + InList { + expr, + list, + negated, + }, + _, + ) => (expr, list) .map_elements(f)? .update_data(|(new_expr, new_list)| { - Expr::InList(InList::new(new_expr, new_list, negated)) + Expr::_in_list(InList::new(new_expr, new_list, negated)) }), }) } } +impl Expr { + pub fn stats(&self) -> LogicalPlanStats { + match self { + Expr::Alias(_, stats) => *stats, + Expr::Column(_, stats) => *stats, + Expr::ScalarVariable(_, _, stats) => *stats, + Expr::Literal(_, stats) => *stats, + Expr::BinaryExpr(_, stats) => *stats, + Expr::Like(_, stats) => *stats, + Expr::SimilarTo(_, stats) => *stats, + Expr::Not(_, stats) => *stats, + Expr::IsNotNull(_, stats) => *stats, + Expr::IsNull(_, stats) => *stats, + Expr::IsTrue(_, stats) => *stats, + Expr::IsFalse(_, stats) => *stats, + Expr::IsUnknown(_, stats) => *stats, + Expr::IsNotTrue(_, stats) => *stats, + Expr::IsNotFalse(_, stats) => *stats, + Expr::IsNotUnknown(_, stats) => *stats, + Expr::Negative(_, stats) => *stats, + Expr::Between(_, stats) => *stats, + Expr::Case(_, stats) => *stats, + Expr::Cast(_, stats) => *stats, + Expr::TryCast(_, stats) => *stats, + Expr::ScalarFunction(_, stats) => *stats, + Expr::AggregateFunction(_, stats) => *stats, + Expr::WindowFunction(_, stats) => *stats, + Expr::InList(_, stats) => *stats, + Expr::Exists(_, stats) => *stats, + Expr::InSubquery(_, stats) => *stats, + Expr::ScalarSubquery(_, stats) => *stats, + Expr::Wildcard(_, stats) => *stats, + Expr::GroupingSet(_, stats) => *stats, + Expr::Placeholder(_, stats) => *stats, + Expr::OuterReferenceColumn(_, _, stats) => *stats, + Expr::Unnest(_, stats) => *stats, + } + } +} diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index 28506caceea28..83192a1c8ed12 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -145,7 +145,7 @@ impl AggregateUDF { /// This utility allows using the UDAF without requiring access to /// the registry, such as with the DataFrame API. pub fn call(&self, args: Vec) -> Expr { - Expr::AggregateFunction(AggregateFunction::new_udf( + Expr::aggregate_function(AggregateFunction::new_udf( Arc::new(self.clone()), args, false, diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index 9a588cf43be67..b3ac8d36cd52d 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -123,7 +123,7 @@ impl ScalarUDF { /// let expr = my_func.call(vec![col("a"), lit(12.3)]); /// ``` pub fn call(&self, args: Vec) -> Expr { - Expr::ScalarFunction(crate::expr::ScalarFunction::new_udf( + Expr::scalar_function(crate::expr::ScalarFunction::new_udf( Arc::new(self.clone()), args, )) diff --git a/datafusion/expr/src/udwf.rs b/datafusion/expr/src/udwf.rs index 11bc7043da761..b3c227c31b0b3 100644 --- a/datafusion/expr/src/udwf.rs +++ b/datafusion/expr/src/udwf.rs @@ -130,7 +130,7 @@ impl WindowUDF { pub fn call(&self, args: Vec) -> Expr { let fun = crate::WindowFunctionDefinition::WindowUDF(Arc::new(self.clone())); - Expr::WindowFunction(WindowFunction::new(fun, args)) + Expr::window_function(WindowFunction::new(fun, args)) } /// Returns this function's name diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index a1ab142fa8355..489875b5e39fd 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -61,7 +61,7 @@ pub fn exprlist_to_columns(expr: &[Expr], accum: &mut HashSet) -> Result /// Count the number of distinct exprs in a list of group by expressions. If the /// first element is a `GroupingSet` expression then it must be the only expr. pub fn grouping_set_expr_count(group_expr: &[Expr]) -> Result { - if let Some(Expr::GroupingSet(grouping_set)) = group_expr.first() { + if let Some(Expr::GroupingSet(grouping_set, _)) = group_expr.first() { if group_expr.len() > 1 { return plan_err!( "Invalid group by expressions, GroupingSet must be the only expression" @@ -201,7 +201,7 @@ fn cross_join_grouping_sets( pub fn enumerate_grouping_sets(group_expr: Vec) -> Result> { let has_grouping_set = group_expr .iter() - .any(|expr| matches!(expr, Expr::GroupingSet(_))); + .any(|expr| matches!(expr, Expr::GroupingSet(_, _))); if !has_grouping_set || group_expr.len() == 1 { return Ok(group_expr); } @@ -210,17 +210,17 @@ pub fn enumerate_grouping_sets(group_expr: Vec) -> Result> { .iter() .map(|expr| { let exprs = match expr { - Expr::GroupingSet(GroupingSet::GroupingSets(grouping_sets)) => { + Expr::GroupingSet(GroupingSet::GroupingSets(grouping_sets), _) => { check_grouping_sets_size_limit(grouping_sets.len())?; grouping_sets.iter().map(|e| e.iter().collect()).collect() } - Expr::GroupingSet(GroupingSet::Cube(group_exprs)) => { + Expr::GroupingSet(GroupingSet::Cube(group_exprs), _) => { let grouping_sets = powerset(group_exprs) .map_err(|e| plan_datafusion_err!("{}", e))?; check_grouping_sets_size_limit(grouping_sets.len())?; grouping_sets } - Expr::GroupingSet(GroupingSet::Rollup(group_exprs)) => { + Expr::GroupingSet(GroupingSet::Rollup(group_exprs), _) => { let size = group_exprs.len(); let slice = group_exprs.as_slice(); check_grouping_sets_size_limit(size * (size + 1) / 2 + 1)?; @@ -247,7 +247,7 @@ pub fn enumerate_grouping_sets(group_expr: Vec) -> Result> { }) .unwrap_or_default(); - Ok(vec![Expr::GroupingSet(GroupingSet::GroupingSets( + Ok(vec![Expr::grouping_set(GroupingSet::GroupingSets( grouping_sets, ))]) } @@ -255,7 +255,7 @@ pub fn enumerate_grouping_sets(group_expr: Vec) -> Result> { /// Find all distinct exprs in a list of group by expressions. If the /// first element is a `GroupingSet` expression then it must be the only expr. pub fn grouping_set_to_exprlist(group_expr: &[Expr]) -> Result> { - if let Some(Expr::GroupingSet(grouping_set)) = group_expr.first() { + if let Some(Expr::GroupingSet(grouping_set, _)) = group_expr.first() { if group_expr.len() > 1 { return plan_err!( "Invalid group by expressions, GroupingSet must be the only expression" @@ -276,29 +276,29 @@ pub fn grouping_set_to_exprlist(group_expr: &[Expr]) -> Result> { pub fn expr_to_columns(expr: &Expr, accum: &mut HashSet) -> Result<()> { expr.apply(|expr| { match expr { - Expr::Column(qc) => { + Expr::Column(qc, _) => { accum.insert(qc.clone()); } // Use explicit pattern match instead of a default // implementation, so that in the future if someone adds // new Expr types, they will check here as well - Expr::Unnest(_) - | Expr::ScalarVariable(_, _) - | Expr::Alias(_) - | Expr::Literal(_) + Expr::Unnest(_, _) + | Expr::ScalarVariable(_, _, _) + | Expr::Alias(_, _) + | Expr::Literal(_, _) | Expr::BinaryExpr { .. } | Expr::Like { .. } | Expr::SimilarTo { .. } - | Expr::Not(_) - | Expr::IsNotNull(_) - | Expr::IsNull(_) - | Expr::IsTrue(_) - | Expr::IsFalse(_) - | Expr::IsUnknown(_) - | Expr::IsNotTrue(_) - | Expr::IsNotFalse(_) - | Expr::IsNotUnknown(_) - | Expr::Negative(_) + | Expr::Not(_, _) + | Expr::IsNotNull(_, _) + | Expr::IsNull(_, _) + | Expr::IsTrue(_, _) + | Expr::IsFalse(_, _) + | Expr::IsUnknown(_, _) + | Expr::IsNotTrue(_, _) + | Expr::IsNotFalse(_, _) + | Expr::IsNotUnknown(_, _) + | Expr::Negative(_, _) | Expr::Between { .. } | Expr::Case { .. } | Expr::Cast { .. } @@ -306,13 +306,13 @@ pub fn expr_to_columns(expr: &Expr, accum: &mut HashSet) -> Result<()> { | Expr::ScalarFunction(..) | Expr::WindowFunction { .. } | Expr::AggregateFunction { .. } - | Expr::GroupingSet(_) + | Expr::GroupingSet(_, _) | Expr::InList { .. } | Expr::Exists { .. } - | Expr::InSubquery(_) - | Expr::ScalarSubquery(_) + | Expr::InSubquery(_, _) + | Expr::ScalarSubquery(_, _) | Expr::Wildcard { .. } - | Expr::Placeholder(_) + | Expr::Placeholder(_, _) | Expr::OuterReferenceColumn { .. } => {} } Ok(TreeNodeRecursion::Continue) @@ -370,7 +370,7 @@ fn get_exprs_except_skipped( .iter() .filter_map(|c| { if !columns_to_skip.contains(c) { - Some(Expr::Column(c.clone())) + Some(Expr::column(c.clone())) } else { None } @@ -582,7 +582,7 @@ pub fn group_window_expr_by_sort_keys( ) -> Result)>> { let mut result = vec![]; window_expr.into_iter().try_for_each(|expr| match &expr { - Expr::WindowFunction( WindowFunction{ partition_by, order_by, .. }) => { + Expr::WindowFunction( WindowFunction{ partition_by, order_by, .. }, _) => { let sort_key = generate_sort_key(partition_by, order_by)?; if let Some((_, values)) = result.iter_mut().find( |group: &&mut (WindowSortKey, Vec)| matches!(group, (key, _) if *key == sort_key), @@ -703,7 +703,7 @@ pub fn exprlist_to_fields<'a>( let result = exprs .into_iter() .map(|e| match e { - Expr::Wildcard(Wildcard { qualifier, options }) => match qualifier { + Expr::Wildcard(Wildcard { qualifier, options }, _) => match qualifier { None => { let excluded: Vec = get_excluded_columns( options.exclude.as_ref(), @@ -769,18 +769,18 @@ pub fn exprlist_to_fields<'a>( /// If we expand a wildcard expression basing the intermediate plan, we could get some duplicate fields. pub fn find_base_plan(input: &LogicalPlan) -> &LogicalPlan { match input { - LogicalPlan::Window(window) => find_base_plan(&window.input), - LogicalPlan::Aggregate(agg) => find_base_plan(&agg.input), + LogicalPlan::Window(window, _) => find_base_plan(&window.input), + LogicalPlan::Aggregate(agg, _) => find_base_plan(&agg.input), // [SqlToRel::try_process_unnest] will convert Expr(Unnest(Expr)) to Projection/Unnest/Projection // We should expand the wildcard expression based on the input plan of the inner Projection. - LogicalPlan::Unnest(unnest) => { - if let LogicalPlan::Projection(projection) = unnest.input.deref() { + LogicalPlan::Unnest(unnest, _) => { + if let LogicalPlan::Projection(projection, _) = unnest.input.deref() { find_base_plan(&projection.input) } else { input } } - LogicalPlan::Filter(filter) => { + LogicalPlan::Filter(filter, _) => { if filter.having { // If a filter is used for a having clause, its input plan is an aggregation. // We should expand the wildcard expression based on the aggregation's input plan. @@ -802,10 +802,13 @@ pub fn exprlist_len( exprs .iter() .map(|e| match e { - Expr::Wildcard(Wildcard { - qualifier: None, - options, - }) => { + Expr::Wildcard( + Wildcard { + qualifier: None, + options, + }, + _, + ) => { let excluded = get_excluded_columns( options.exclude.as_ref(), options.except.as_ref(), @@ -819,10 +822,13 @@ pub fn exprlist_len( .len(), ) } - Expr::Wildcard(Wildcard { - qualifier: Some(qualifier), - options, - }) => { + Expr::Wildcard( + Wildcard { + qualifier: Some(qualifier), + options, + }, + _, + ) => { let related_wildcard_schema = wildcard_schema.as_ref().map_or_else( || Ok(Arc::clone(schema)), |schema| { @@ -886,7 +892,7 @@ pub fn columnize_expr(e: Expr, input: &LogicalPlan) -> Result { let exprs_map: HashMap<&Expr, Column> = output_exprs.into_iter().collect(); e.transform_down(|node: Expr| match exprs_map.get(&node) { Some(column) => Ok(Transformed::new( - Expr::Column(column.clone()), + Expr::column(column.clone()), true, TreeNodeRecursion::Jump, )), @@ -901,14 +907,14 @@ pub fn find_column_exprs(exprs: &[Expr]) -> Vec { exprs .iter() .flat_map(find_columns_referenced_by_expr) - .map(Expr::Column) + .map(Expr::column) .collect() } pub(crate) fn find_columns_referenced_by_expr(e: &Expr) -> Vec { let mut exprs = vec![]; e.apply(|expr| { - if let Expr::Column(c) = expr { + if let Expr::Column(c, _) = expr { exprs.push(c.clone()) } Ok(TreeNodeRecursion::Continue) @@ -921,11 +927,11 @@ pub(crate) fn find_columns_referenced_by_expr(e: &Expr) -> Vec { /// Convert any `Expr` to an `Expr::Column`. pub fn expr_as_column_expr(expr: &Expr, plan: &LogicalPlan) -> Result { match expr { - Expr::Column(col) => { + Expr::Column(col, _) => { let (qualifier, field) = plan.schema().qualified_field_from_column(col)?; Ok(Expr::from(Column::from((qualifier, field)))) } - _ => Ok(Expr::Column(Column::from_name( + _ => Ok(Expr::column(Column::from_name( expr.schema_name().to_string(), ))), } @@ -940,12 +946,12 @@ pub(crate) fn find_column_indexes_referenced_by_expr( let mut indexes = vec![]; e.apply(|expr| { match expr { - Expr::Column(qc) => { + Expr::Column(qc, _) => { if let Ok(idx) = schema.index_of_column(qc) { indexes.push(idx); } } - Expr::Literal(_) => { + Expr::Literal(_, _) => { indexes.push(usize::MAX); } _ => {} @@ -1096,15 +1102,18 @@ pub fn split_conjunction(expr: &Expr) -> Vec<&Expr> { fn split_conjunction_impl<'a>(expr: &'a Expr, mut exprs: Vec<&'a Expr>) -> Vec<&'a Expr> { match expr { - Expr::BinaryExpr(BinaryExpr { - right, - op: Operator::And, - left, - }) => { + Expr::BinaryExpr( + BinaryExpr { + right, + op: Operator::And, + left, + }, + _, + ) => { let exprs = split_conjunction_impl(left, exprs); split_conjunction_impl(right, exprs) } - Expr::Alias(Alias { expr, .. }) => split_conjunction_impl(expr, exprs), + Expr::Alias(Alias { expr, .. }, _) => split_conjunction_impl(expr, exprs), other => { exprs.push(other); exprs @@ -1120,15 +1129,18 @@ pub fn iter_conjunction(expr: &Expr) -> impl Iterator { std::iter::from_fn(move || { while let Some(expr) = stack.pop() { match expr { - Expr::BinaryExpr(BinaryExpr { - right, - op: Operator::And, - left, - }) => { + Expr::BinaryExpr( + BinaryExpr { + right, + op: Operator::And, + left, + }, + _, + ) => { stack.push(right); stack.push(left); } - Expr::Alias(Alias { expr, .. }) => stack.push(expr), + Expr::Alias(Alias { expr, .. }, _) => stack.push(expr), other => return Some(other), } } @@ -1144,15 +1156,18 @@ pub fn iter_conjunction_owned(expr: Expr) -> impl Iterator { std::iter::from_fn(move || { while let Some(expr) = stack.pop() { match expr { - Expr::BinaryExpr(BinaryExpr { - right, - op: Operator::And, - left, - }) => { + Expr::BinaryExpr( + BinaryExpr { + right, + op: Operator::And, + left, + }, + _, + ) => { stack.push(*right); stack.push(*left); } - Expr::Alias(Alias { expr, .. }) => stack.push(*expr), + Expr::Alias(Alias { expr, .. }, _) => stack.push(*expr), other => return Some(other), } } @@ -1217,11 +1232,11 @@ fn split_binary_owned_impl( mut exprs: Vec, ) -> Vec { match expr { - Expr::BinaryExpr(BinaryExpr { right, op, left }) if op == operator => { + Expr::BinaryExpr(BinaryExpr { right, op, left }, _) if op == operator => { let exprs = split_binary_owned_impl(*left, operator, exprs); split_binary_owned_impl(*right, operator, exprs) } - Expr::Alias(Alias { expr, .. }) => { + Expr::Alias(Alias { expr, .. }, _) => { split_binary_owned_impl(*expr, operator, exprs) } other => { @@ -1244,11 +1259,11 @@ fn split_binary_impl<'a>( mut exprs: Vec<&'a Expr>, ) -> Vec<&'a Expr> { match expr { - Expr::BinaryExpr(BinaryExpr { right, op, left }) if *op == operator => { + Expr::BinaryExpr(BinaryExpr { right, op, left }, _) if *op == operator => { let exprs = split_binary_impl(left, operator, exprs); split_binary_impl(right, operator, exprs) } - Expr::Alias(Alias { expr, .. }) => split_binary_impl(expr, operator, exprs), + Expr::Alias(Alias { expr, .. }, _) => split_binary_impl(expr, operator, exprs), other => { exprs.push(other); exprs @@ -1331,7 +1346,7 @@ pub fn add_filter(plan: LogicalPlan, predicates: &[&Expr]) -> Result) -> Result<(Vec, Vec)> { for filter in exprs.into_iter() { // If the expression contains correlated predicates, add it to join filters if filter.contains_outer() { - if !matches!(filter, Expr::BinaryExpr(BinaryExpr{ left, op: Operator::Eq, right }) if left.eq(right)) + if !matches!(filter, Expr::BinaryExpr(BinaryExpr{ left, op: Operator::Eq, right }, _) if left.eq(right)) { joins.push(strip_outer_reference((*filter).clone())); } @@ -1422,19 +1437,19 @@ mod tests { #[test] fn test_group_window_expr_by_sort_keys_empty_window() -> Result<()> { - let max1 = Expr::WindowFunction(WindowFunction::new( + let max1 = Expr::window_function(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![col("name")], )); - let max2 = Expr::WindowFunction(WindowFunction::new( + let max2 = Expr::window_function(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![col("name")], )); - let min3 = Expr::WindowFunction(WindowFunction::new( + let min3 = Expr::window_function(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(min_udaf()), vec![col("name")], )); - let sum4 = Expr::WindowFunction(WindowFunction::new( + let sum4 = Expr::window_function(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(sum_udaf()), vec![col("age")], )); @@ -1452,25 +1467,25 @@ mod tests { let age_asc = Sort::new(col("age"), true, true); let name_desc = Sort::new(col("name"), false, true); let created_at_desc = Sort::new(col("created_at"), false, true); - let max1 = Expr::WindowFunction(WindowFunction::new( + let max1 = Expr::window_function(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![col("name")], )) .order_by(vec![age_asc.clone(), name_desc.clone()]) .build() .unwrap(); - let max2 = Expr::WindowFunction(WindowFunction::new( + let max2 = Expr::window_function(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![col("name")], )); - let min3 = Expr::WindowFunction(WindowFunction::new( + let min3 = Expr::window_function(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(min_udaf()), vec![col("name")], )) .order_by(vec![age_asc.clone(), name_desc.clone()]) .build() .unwrap(); - let sum4 = Expr::WindowFunction(WindowFunction::new( + let sum4 = Expr::window_function(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(sum_udaf()), vec![col("age")], )) @@ -1802,11 +1817,11 @@ mod tests { fn test_collect_expr() -> Result<()> { let mut accum: HashSet = HashSet::new(); expr_to_columns( - &Expr::Cast(Cast::new(Box::new(col("a")), DataType::Float64)), + &Expr::cast(Cast::new(Box::new(col("a")), DataType::Float64)), &mut accum, )?; expr_to_columns( - &Expr::Cast(Cast::new(Box::new(col("a")), DataType::Float64)), + &Expr::cast(Cast::new(Box::new(col("a")), DataType::Float64)), &mut accum, )?; assert_eq!(1, accum.len()); diff --git a/datafusion/functions-aggregate/src/count.rs b/datafusion/functions-aggregate/src/count.rs index 4b1ab323b8f44..c48306551c25f 100644 --- a/datafusion/functions-aggregate/src/count.rs +++ b/datafusion/functions-aggregate/src/count.rs @@ -69,7 +69,7 @@ make_udaf_expr_and_func!( ); pub fn count_distinct(expr: Expr) -> Expr { - Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf( + Expr::aggregate_function(datafusion_expr::expr::AggregateFunction::new_udf( count_udaf(), vec![expr], true, diff --git a/datafusion/functions-aggregate/src/macros.rs b/datafusion/functions-aggregate/src/macros.rs index ffb5183278e67..aebcfbbe409ac 100644 --- a/datafusion/functions-aggregate/src/macros.rs +++ b/datafusion/functions-aggregate/src/macros.rs @@ -22,7 +22,7 @@ macro_rules! make_udaf_expr { pub fn $EXPR_FN( $($arg: datafusion_expr::Expr,)* ) -> datafusion_expr::Expr { - datafusion_expr::Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf( + datafusion_expr::Expr::aggregate_function(datafusion_expr::expr::AggregateFunction::new_udf( $AGGREGATE_UDF_FN(), vec![$($arg),*], false, @@ -45,7 +45,7 @@ macro_rules! make_udaf_expr_and_func { pub fn $EXPR_FN( args: Vec, ) -> datafusion_expr::Expr { - datafusion_expr::Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf( + datafusion_expr::Expr::aggregate_function(datafusion_expr::expr::AggregateFunction::new_udf( $AGGREGATE_UDF_FN(), args, false, diff --git a/datafusion/functions-nested/benches/map.rs b/datafusion/functions-nested/benches/map.rs index 3c4a09c659925..3a753b9f226ee 100644 --- a/datafusion/functions-nested/benches/map.rs +++ b/datafusion/functions-nested/benches/map.rs @@ -58,8 +58,8 @@ fn criterion_benchmark(c: &mut Criterion) { let values = values(&mut rng); let mut buffer = Vec::new(); for i in 0..1000 { - buffer.push(Expr::Literal(ScalarValue::Utf8(Some(keys[i].clone())))); - buffer.push(Expr::Literal(ScalarValue::Int32(Some(values[i])))); + buffer.push(Expr::literal(ScalarValue::Utf8(Some(keys[i].clone())))); + buffer.push(Expr::literal(ScalarValue::Int32(Some(values[i])))); } let planner = NestedFunctionPlanner {}; diff --git a/datafusion/functions-nested/src/macros.rs b/datafusion/functions-nested/src/macros.rs index 00247f39ac10f..b561e4ae76aa8 100644 --- a/datafusion/functions-nested/src/macros.rs +++ b/datafusion/functions-nested/src/macros.rs @@ -49,7 +49,7 @@ macro_rules! make_udf_expr_and_func { // "fluent expr_fn" style function #[doc = $DOC] pub fn $EXPR_FN($($arg: datafusion_expr::Expr),*) -> datafusion_expr::Expr { - datafusion_expr::Expr::ScalarFunction(datafusion_expr::expr::ScalarFunction::new_udf( + datafusion_expr::Expr::scalar_function(datafusion_expr::expr::ScalarFunction::new_udf( $SCALAR_UDF_FN(), vec![$($arg),*], )) @@ -62,7 +62,7 @@ macro_rules! make_udf_expr_and_func { // "fluent expr_fn" style function #[doc = $DOC] pub fn $EXPR_FN(arg: Vec) -> datafusion_expr::Expr { - datafusion_expr::Expr::ScalarFunction(datafusion_expr::expr::ScalarFunction::new_udf( + datafusion_expr::Expr::scalar_function(datafusion_expr::expr::ScalarFunction::new_udf( $SCALAR_UDF_FN(), arg, )) diff --git a/datafusion/functions-nested/src/map.rs b/datafusion/functions-nested/src/map.rs index 9f6a5031ac4e6..a5e4d0094f9ee 100644 --- a/datafusion/functions-nested/src/map.rs +++ b/datafusion/functions-nested/src/map.rs @@ -38,7 +38,7 @@ use crate::make_array::make_array; pub fn map(keys: Vec, values: Vec) -> Expr { let keys = make_array(keys); let values = make_array(values); - Expr::ScalarFunction(ScalarFunction::new_udf(map_udf(), vec![keys, values])) + Expr::scalar_function(ScalarFunction::new_udf(map_udf(), vec![keys, values])) } create_func!(MapFunc, map_udf); diff --git a/datafusion/functions-nested/src/planner.rs b/datafusion/functions-nested/src/planner.rs index 1929b8222a1b6..a390998f60fc7 100644 --- a/datafusion/functions-nested/src/planner.rs +++ b/datafusion/functions-nested/src/planner.rs @@ -111,14 +111,14 @@ impl ExprPlanner for NestedFunctionPlanner { let keys = make_array(keys.into_iter().map(|(_, e)| e).collect()); let values = make_array(values.into_iter().map(|(_, e)| e).collect()); - Ok(PlannerResult::Planned(Expr::ScalarFunction( + Ok(PlannerResult::Planned(Expr::scalar_function( ScalarFunction::new_udf(map_udf(), vec![keys, values]), ))) } fn plan_any(&self, expr: RawBinaryExpr) -> Result> { if expr.op == sqlparser::ast::BinaryOperator::Eq { - Ok(PlannerResult::Planned(Expr::ScalarFunction( + Ok(PlannerResult::Planned(Expr::scalar_function( ScalarFunction::new_udf( array_has_udf(), // left and right are reversed here so `needle=any(haystack)` -> `array_has(haystack, needle)` @@ -150,8 +150,8 @@ impl ExprPlanner for FieldAccessPlanner { GetFieldAccess::ListIndex { key: index } => { match expr { // Special case for array_agg(expr)[index] to NTH_VALUE(expr, index) - Expr::AggregateFunction(agg_func) if is_array_agg(&agg_func) => { - Ok(PlannerResult::Planned(Expr::AggregateFunction( + Expr::AggregateFunction(agg_func, _) if is_array_agg(&agg_func) => { + Ok(PlannerResult::Planned(Expr::aggregate_function( datafusion_expr::expr::AggregateFunction::new_udf( nth_value_udaf(), agg_func diff --git a/datafusion/functions/src/core/arrow_cast.rs b/datafusion/functions/src/core/arrow_cast.rs index 3e727a1765388..5fe9077434450 100644 --- a/datafusion/functions/src/core/arrow_cast.rs +++ b/datafusion/functions/src/core/arrow_cast.rs @@ -125,7 +125,7 @@ impl ScalarUDFImpl for ArrowCastFunc { arg } else { // Use an actual cast to get the correct type - Expr::Cast(datafusion_expr::Cast { + Expr::cast(datafusion_expr::Cast { expr: Box::new(arg), data_type: target_type, }) @@ -172,7 +172,7 @@ fn data_type_from_args(args: &[Expr]) -> Result { if args.len() != 2 { return plan_err!("arrow_cast needs 2 arguments, {} provided", args.len()); } - let Expr::Literal(ScalarValue::Utf8(Some(val))) = &args[1] else { + let Expr::Literal(ScalarValue::Utf8(Some(val)), _) = &args[1] else { return plan_err!( "arrow_cast requires its second argument to be a constant string, got {:?}", &args[1] diff --git a/datafusion/functions/src/core/getfield.rs b/datafusion/functions/src/core/getfield.rs index 94ab56ce1ed82..d257e05e0f056 100644 --- a/datafusion/functions/src/core/getfield.rs +++ b/datafusion/functions/src/core/getfield.rs @@ -67,7 +67,7 @@ impl ScalarUDFImpl for GetFieldFunc { } let name = match &args[1] { - Expr::Literal(name) => name, + Expr::Literal(name, _) => name, _ => { return exec_err!( "get_field function requires the argument field_name to be a string" @@ -87,7 +87,7 @@ impl ScalarUDFImpl for GetFieldFunc { } let name = match &args[1] { - Expr::Literal(name) => name, + Expr::Literal(name, _) => name, _ => { return exec_err!( "get_field function requires the argument field_name to be a string" @@ -120,7 +120,7 @@ impl ScalarUDFImpl for GetFieldFunc { } let name = match &args[1] { - Expr::Literal(name) => name, + Expr::Literal(name, _) => name, _ => { return exec_err!( "get_field function requires the argument field_name to be a string" diff --git a/datafusion/functions/src/core/named_struct.rs b/datafusion/functions/src/core/named_struct.rs index 963b70514838e..1f5a35cfd97d8 100644 --- a/datafusion/functions/src/core/named_struct.rs +++ b/datafusion/functions/src/core/named_struct.rs @@ -148,7 +148,7 @@ impl ScalarUDFImpl for NamedStructFunc { let name = &chunk[0]; let value = &chunk[1]; - if let Expr::Literal(ScalarValue::Utf8(Some(name))) = name { + if let Expr::Literal(ScalarValue::Utf8(Some(name)), _) = name { Ok(Field::new(name, value.get_type(schema)?, true)) } else { exec_err!("named_struct even arguments must be string literals, got {name} instead at position {}", i * 2) diff --git a/datafusion/functions/src/core/planner.rs b/datafusion/functions/src/core/planner.rs index 717a74797c0b5..faf4b3c21551a 100644 --- a/datafusion/functions/src/core/planner.rs +++ b/datafusion/functions/src/core/planner.rs @@ -46,7 +46,7 @@ impl ExprPlanner for CoreFunctionPlanner { args: Vec, is_named_struct: bool, ) -> Result>> { - Ok(PlannerResult::Planned(Expr::ScalarFunction( + Ok(PlannerResult::Planned(Expr::scalar_function( ScalarFunction::new_udf( if is_named_struct { named_struct() @@ -59,7 +59,7 @@ impl ExprPlanner for CoreFunctionPlanner { } fn plan_overlay(&self, args: Vec) -> Result>> { - Ok(PlannerResult::Planned(Expr::ScalarFunction( + Ok(PlannerResult::Planned(Expr::scalar_function( ScalarFunction::new_udf(crate::string::overlay(), args), ))) } @@ -70,7 +70,7 @@ impl ExprPlanner for CoreFunctionPlanner { qualifier: Option<&TableReference>, nested_names: &[String], ) -> Result>> { - let col = Expr::Column(Column::from((qualifier, field))); + let col = Expr::column(Column::from((qualifier, field))); // Start with the base column expression let mut expr = col; @@ -78,7 +78,7 @@ impl ExprPlanner for CoreFunctionPlanner { // Iterate over nested_names and create nested get_field expressions for nested_name in nested_names { let get_field_args = vec![expr, lit(ScalarValue::from(nested_name.clone()))]; - expr = Expr::ScalarFunction(ScalarFunction::new_udf( + expr = Expr::scalar_function(ScalarFunction::new_udf( crate::core::get_field(), get_field_args, )); diff --git a/datafusion/functions/src/datetime/current_date.rs b/datafusion/functions/src/datetime/current_date.rs index d4c8be366d063..9117420633b06 100644 --- a/datafusion/functions/src/datetime/current_date.rs +++ b/datafusion/functions/src/datetime/current_date.rs @@ -95,7 +95,7 @@ impl ScalarUDFImpl for CurrentDateFunc { .unwrap() .num_days_from_ce(), ); - Ok(ExprSimplifyResult::Simplified(Expr::Literal( + Ok(ExprSimplifyResult::Simplified(Expr::literal( ScalarValue::Date32(days), ))) } diff --git a/datafusion/functions/src/datetime/current_time.rs b/datafusion/functions/src/datetime/current_time.rs index 102262d0e1957..c1556cdebbb89 100644 --- a/datafusion/functions/src/datetime/current_time.rs +++ b/datafusion/functions/src/datetime/current_time.rs @@ -83,7 +83,7 @@ impl ScalarUDFImpl for CurrentTimeFunc { ) -> Result { let now_ts = info.execution_props().query_execution_start_time; let nano = now_ts.timestamp_nanos_opt().map(|ts| ts % 86400000000000); - Ok(ExprSimplifyResult::Simplified(Expr::Literal( + Ok(ExprSimplifyResult::Simplified(Expr::literal( ScalarValue::Time64Nanosecond(nano), ))) } diff --git a/datafusion/functions/src/datetime/date_part.rs b/datafusion/functions/src/datetime/date_part.rs index eb221df9827e4..e44509b759801 100644 --- a/datafusion/functions/src/datetime/date_part.rs +++ b/datafusion/functions/src/datetime/date_part.rs @@ -158,7 +158,7 @@ impl ScalarUDFImpl for DatePartFunc { _arg_types: &[DataType], ) -> Result { match &args[0] { - Expr::Literal(ScalarValue::Utf8(Some(part))) if is_epoch(part) => { + Expr::Literal(ScalarValue::Utf8(Some(part)), _) if is_epoch(part) => { Ok(DataType::Float64) } _ => Ok(DataType::Int32), diff --git a/datafusion/functions/src/datetime/from_unixtime.rs b/datafusion/functions/src/datetime/from_unixtime.rs index 5c6363b23741c..5832251bd94fa 100644 --- a/datafusion/functions/src/datetime/from_unixtime.rs +++ b/datafusion/functions/src/datetime/from_unixtime.rs @@ -72,7 +72,7 @@ impl ScalarUDFImpl for FromUnixtimeFunc { match arg_types.len() { 1 => Ok(Timestamp(Second, None)), 2 => match &args[1] { - Expr::Literal(ScalarValue::Utf8(Some(tz))) => Ok(Timestamp(Second, Some(Arc::from(tz.to_string())))), + Expr::Literal(ScalarValue::Utf8(Some(tz)), _) => Ok(Timestamp(Second, Some(Arc::from(tz.to_string())))), _ => exec_err!( "Second argument for `from_unixtime` must be non-null utf8, received {:?}", arg_types[1]), diff --git a/datafusion/functions/src/datetime/now.rs b/datafusion/functions/src/datetime/now.rs index c2d42ec1bf784..254e7e52f307f 100644 --- a/datafusion/functions/src/datetime/now.rs +++ b/datafusion/functions/src/datetime/now.rs @@ -85,7 +85,7 @@ impl ScalarUDFImpl for NowFunc { .execution_props() .query_execution_start_time .timestamp_nanos_opt(); - Ok(ExprSimplifyResult::Simplified(Expr::Literal( + Ok(ExprSimplifyResult::Simplified(Expr::literal( ScalarValue::TimestampNanosecond(now_ts, Some("+00:00".into())), ))) } diff --git a/datafusion/functions/src/math/log.rs b/datafusion/functions/src/math/log.rs index 438fb34d15521..fededea3f538c 100644 --- a/datafusion/functions/src/math/log.rs +++ b/datafusion/functions/src/math/log.rs @@ -216,12 +216,14 @@ impl ScalarUDFImpl for LogFunc { }; match number { - Expr::Literal(value) if value == ScalarValue::new_one(&number_datatype)? => { + Expr::Literal(value, _) + if value == ScalarValue::new_one(&number_datatype)? => + { Ok(ExprSimplifyResult::Simplified(lit(ScalarValue::new_zero( &info.get_data_type(&base)?, )?))) } - Expr::ScalarFunction(ScalarFunction { func, mut args }) + Expr::ScalarFunction(ScalarFunction { func, mut args }, _) if is_pow(&func) && args.len() == 2 && base == args[0] => { let b = args.pop().unwrap(); // length checked above diff --git a/datafusion/functions/src/math/power.rs b/datafusion/functions/src/math/power.rs index cd45d10e8e8ef..09bdff23f4863 100644 --- a/datafusion/functions/src/math/power.rs +++ b/datafusion/functions/src/math/power.rs @@ -147,15 +147,17 @@ impl ScalarUDFImpl for PowerFunc { let exponent_type = info.get_data_type(&exponent)?; match exponent { - Expr::Literal(value) if value == ScalarValue::new_zero(&exponent_type)? => { - Ok(ExprSimplifyResult::Simplified(Expr::Literal( + Expr::Literal(value, _) + if value == ScalarValue::new_zero(&exponent_type)? => + { + Ok(ExprSimplifyResult::Simplified(Expr::literal( ScalarValue::new_one(&info.get_data_type(&base)?)?, ))) } - Expr::Literal(value) if value == ScalarValue::new_one(&exponent_type)? => { + Expr::Literal(value, _) if value == ScalarValue::new_one(&exponent_type)? => { Ok(ExprSimplifyResult::Simplified(base)) } - Expr::ScalarFunction(ScalarFunction { func, mut args }) + Expr::ScalarFunction(ScalarFunction { func, mut args }, _) if is_log(&func) && args.len() == 2 && base == args[0] => { let b = args.pop().unwrap(); // length checked above diff --git a/datafusion/functions/src/planner.rs b/datafusion/functions/src/planner.rs index 93edec7ece307..31a1c2d10b0f9 100644 --- a/datafusion/functions/src/planner.rs +++ b/datafusion/functions/src/planner.rs @@ -30,21 +30,21 @@ pub struct UserDefinedFunctionPlanner; impl ExprPlanner for UserDefinedFunctionPlanner { #[cfg(feature = "datetime_expressions")] fn plan_extract(&self, args: Vec) -> Result>> { - Ok(PlannerResult::Planned(Expr::ScalarFunction( + Ok(PlannerResult::Planned(Expr::scalar_function( ScalarFunction::new_udf(crate::datetime::date_part(), args), ))) } #[cfg(feature = "unicode_expressions")] fn plan_position(&self, args: Vec) -> Result>> { - Ok(PlannerResult::Planned(Expr::ScalarFunction( + Ok(PlannerResult::Planned(Expr::scalar_function( ScalarFunction::new_udf(crate::unicode::strpos(), args), ))) } #[cfg(feature = "unicode_expressions")] fn plan_substring(&self, args: Vec) -> Result>> { - Ok(PlannerResult::Planned(Expr::ScalarFunction( + Ok(PlannerResult::Planned(Expr::scalar_function( ScalarFunction::new_udf(crate::unicode::substr(), args), ))) } diff --git a/datafusion/functions/src/string/concat.rs b/datafusion/functions/src/string/concat.rs index d55d310bf3a08..4d8d65656c344 100644 --- a/datafusion/functions/src/string/concat.rs +++ b/datafusion/functions/src/string/concat.rs @@ -296,7 +296,7 @@ pub fn simplify_concat(args: Vec) -> Result { let data_types: Vec<_> = args .iter() .filter_map(|expr| match expr { - Expr::Literal(l) => Some(l.data_type()), + Expr::Literal(l, _) => Some(l.data_type()), _ => None, }) .collect(); @@ -305,25 +305,25 @@ pub fn simplify_concat(args: Vec) -> Result { for arg in args.clone() { match arg { - Expr::Literal(ScalarValue::Utf8(None)) => {} - Expr::Literal(ScalarValue::LargeUtf8(None)) => { + Expr::Literal(ScalarValue::Utf8(None), _) => {} + Expr::Literal(ScalarValue::LargeUtf8(None), _) => { } - Expr::Literal(ScalarValue::Utf8View(None)) => { } + Expr::Literal(ScalarValue::Utf8View(None), _) => { } // filter out `null` args // All literals have been converted to Utf8 or LargeUtf8 in type_coercion. // Concatenate it with the `contiguous_scalar`. - Expr::Literal(ScalarValue::Utf8(Some(v))) => { + Expr::Literal(ScalarValue::Utf8(Some(v)), _) => { contiguous_scalar += &v; } - Expr::Literal(ScalarValue::LargeUtf8(Some(v))) => { + Expr::Literal(ScalarValue::LargeUtf8(Some(v)), _) => { contiguous_scalar += &v; } - Expr::Literal(ScalarValue::Utf8View(Some(v))) => { + Expr::Literal(ScalarValue::Utf8View(Some(v)), _) => { contiguous_scalar += &v; } - Expr::Literal(x) => { + Expr::Literal(x, _) => { return internal_err!( "The scalar {x} should be casted to string type during the type coercion." ) @@ -360,7 +360,7 @@ pub fn simplify_concat(args: Vec) -> Result { } if !args.eq(&new_args) { - Ok(ExprSimplifyResult::Simplified(Expr::ScalarFunction( + Ok(ExprSimplifyResult::Simplified(Expr::scalar_function( ScalarFunction { func: concat(), args: new_args, diff --git a/datafusion/functions/src/string/concat_ws.rs b/datafusion/functions/src/string/concat_ws.rs index 9118b335bda2b..4efb1e434a239 100644 --- a/datafusion/functions/src/string/concat_ws.rs +++ b/datafusion/functions/src/string/concat_ws.rs @@ -308,6 +308,7 @@ fn simplify_concat_ws(delimiter: &Expr, args: &[Expr]) -> Result { match delimiter { // when the delimiter is an empty string, @@ -320,8 +321,8 @@ fn simplify_concat_ws(delimiter: &Expr, args: &[Expr]) -> Result {} - Expr::Literal(ScalarValue::Utf8(Some(v)) | ScalarValue::LargeUtf8(Some(v)) | ScalarValue::Utf8View(Some(v))) => { + Expr::Literal(ScalarValue::Utf8(None) | ScalarValue::LargeUtf8(None) | ScalarValue::Utf8View(None), _) => {} + Expr::Literal(ScalarValue::Utf8(Some(v)) | ScalarValue::LargeUtf8(Some(v)) | ScalarValue::Utf8View(Some(v)), _) => { match contiguous_scalar { None => contiguous_scalar = Some(v.to_string()), Some(mut pre) => { @@ -331,7 +332,7 @@ fn simplify_concat_ws(delimiter: &Expr, args: &[Expr]) -> Result return internal_err!("The scalar {s} should be casted to string type during the type coercion."), + Expr::Literal(s, _) => return internal_err!("The scalar {s} should be casted to string type during the type coercion."), // If the arg is not a literal, we should first push the current `contiguous_scalar` // to the `new_args` and reset it to None. // Then pushing this arg to the `new_args`. @@ -348,7 +349,7 @@ fn simplify_concat_ws(delimiter: &Expr, args: &[Expr]) -> Result Result Ok(ExprSimplifyResult::Simplified(Expr::Literal( + None => Ok(ExprSimplifyResult::Simplified(Expr::literal( ScalarValue::Utf8(None), ))), } } - Expr::Literal(d) => internal_err!( + Expr::Literal(d, _) => internal_err!( "The scalar {d} should be casted to string type during the type coercion." ), _ => { @@ -378,7 +379,7 @@ fn simplify_concat_ws(delimiter: &Expr, args: &[Expr]) -> Result bool { match expr { - Expr::Literal(v) => v.is_null(), + Expr::Literal(v, _) => v.is_null(), _ => false, } } diff --git a/datafusion/optimizer/Cargo.toml b/datafusion/optimizer/Cargo.toml index 34e35c66107a5..b199b89008961 100644 --- a/datafusion/optimizer/Cargo.toml +++ b/datafusion/optimizer/Cargo.toml @@ -42,6 +42,7 @@ chrono = { workspace = true } datafusion-common = { workspace = true, default-features = true } datafusion-expr = { workspace = true } datafusion-physical-expr = { workspace = true } +enumset = { workspace = true } hashbrown = { workspace = true } indexmap = { workspace = true } itertools = { workspace = true } diff --git a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs index 454afa24b628c..0727e9c8cb98e 100644 --- a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs +++ b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs @@ -72,17 +72,17 @@ fn analyze_internal(plan: LogicalPlan) -> Result> { plan.map_expressions(|expr| { let original_name = name_preserver.save(&expr); let transformed_expr = expr.transform_up(|expr| match expr { - Expr::WindowFunction(mut window_function) + Expr::WindowFunction(mut window_function, _) if is_count_star_window_aggregate(&window_function) => { window_function.args = vec![lit(COUNT_STAR_EXPANSION)]; - Ok(Transformed::yes(Expr::WindowFunction(window_function))) + Ok(Transformed::yes(Expr::window_function(window_function))) } - Expr::AggregateFunction(mut aggregate_function) + Expr::AggregateFunction(mut aggregate_function, _) if is_count_star_aggregate(&aggregate_function) => { aggregate_function.args = vec![lit(COUNT_STAR_EXPANSION)]; - Ok(Transformed::yes(Expr::AggregateFunction( + Ok(Transformed::yes(Expr::aggregate_function( aggregate_function, ))) } @@ -219,7 +219,7 @@ mod tests { let table_scan = test_table_scan()?; let plan = LogicalPlanBuilder::from(table_scan) - .window(vec![Expr::WindowFunction(WindowFunction::new( + .window(vec![Expr::window_function(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(count_udaf()), vec![wildcard()], )) diff --git a/datafusion/optimizer/src/analyzer/expand_wildcard_rule.rs b/datafusion/optimizer/src/analyzer/expand_wildcard_rule.rs index ff9f3df39fd20..1ac32258d0db8 100644 --- a/datafusion/optimizer/src/analyzer/expand_wildcard_rule.rs +++ b/datafusion/optimizer/src/analyzer/expand_wildcard_rule.rs @@ -53,25 +53,25 @@ impl AnalyzerRule for ExpandWildcardRule { fn expand_internal(plan: LogicalPlan) -> Result> { match plan { - LogicalPlan::Projection(Projection { expr, input, .. }) => { + LogicalPlan::Projection(Projection { expr, input, .. }, _) => { let projected_expr = expand_exprlist(&input, expr)?; validate_unique_names("Projections", projected_expr.iter())?; Ok(Transformed::yes( Projection::try_new(projected_expr, Arc::clone(&input)) - .map(LogicalPlan::Projection)?, + .map(LogicalPlan::projection)?, )) } // The schema of the plan should also be updated if the child plan is transformed. - LogicalPlan::SubqueryAlias(SubqueryAlias { input, alias, .. }) => { + LogicalPlan::SubqueryAlias(SubqueryAlias { input, alias, .. }, _) => { Ok(Transformed::yes( - SubqueryAlias::try_new(input, alias).map(LogicalPlan::SubqueryAlias)?, + SubqueryAlias::try_new(input, alias).map(LogicalPlan::subquery_alias)?, )) } - LogicalPlan::Distinct(Distinct::On(distinct_on)) => { + LogicalPlan::Distinct(Distinct::On(distinct_on), _) => { let projected_expr = expand_exprlist(&distinct_on.input, distinct_on.select_expr)?; validate_unique_names("Distinct", projected_expr.iter())?; - Ok(Transformed::yes(LogicalPlan::Distinct(Distinct::On( + Ok(Transformed::yes(LogicalPlan::distinct(Distinct::On( DistinctOn::try_new( distinct_on.on_expr, projected_expr, @@ -89,7 +89,7 @@ fn expand_exprlist(input: &LogicalPlan, expr: Vec) -> Result> { let input = find_base_plan(input); for e in expr { match e { - Expr::Wildcard(Wildcard { qualifier, options }) => { + Expr::Wildcard(Wildcard { qualifier, options }, _) => { if let Some(qualifier) = qualifier { let expanded = expand_qualified_wildcard( &qualifier, @@ -120,10 +120,13 @@ fn expand_exprlist(input: &LogicalPlan, expr: Vec) -> Result> { // A workaround to handle the case when the column name is "*". // We transform the expression to a Expr::Column through [Column::from_name] in many places. // It would also convert the wildcard expression to a column expression with name "*". - Expr::Column(Column { - ref relation, - ref name, - }) => { + Expr::Column( + Column { + ref relation, + ref name, + }, + _, + ) => { if name.eq("*") { if let Some(qualifier) = relation { projected_expr.extend(expand_qualified_wildcard( @@ -157,7 +160,7 @@ fn replace_columns( replace: &PlannedReplaceSelectItem, ) -> Result> { for expr in exprs.iter_mut() { - if let Expr::Column(Column { name, .. }) = expr { + if let Expr::Column(Column { name, .. }, _) = expr { if let Some((_, new_expr)) = replace .items() .iter() diff --git a/datafusion/optimizer/src/analyzer/function_rewrite.rs b/datafusion/optimizer/src/analyzer/function_rewrite.rs index c6bf14ebce2e3..d7ef7cfab0ed9 100644 --- a/datafusion/optimizer/src/analyzer/function_rewrite.rs +++ b/datafusion/optimizer/src/analyzer/function_rewrite.rs @@ -50,7 +50,7 @@ impl ApplyFunctionRewrites { // resolution only, so order does not matter here let mut schema = merge_schema(&plan.inputs()); - if let LogicalPlan::TableScan(ts) = &plan { + if let LogicalPlan::TableScan(ts, _) = &plan { let source_schema = DFSchema::try_from_qualified_schema( ts.table_name.clone(), &ts.source.schema(), diff --git a/datafusion/optimizer/src/analyzer/inline_table_scan.rs b/datafusion/optimizer/src/analyzer/inline_table_scan.rs index 68edda671a7a7..0a1177b4c6bcb 100644 --- a/datafusion/optimizer/src/analyzer/inline_table_scan.rs +++ b/datafusion/optimizer/src/analyzer/inline_table_scan.rs @@ -57,7 +57,7 @@ fn analyze_internal(plan: LogicalPlan) -> Result> { // Match only on scans without filter / projection / fetch // Views and DataFrames won't have those added // during the early stage of planning. - LogicalPlan::TableScan(table_scan) if table_scan.filters.is_empty() => { + LogicalPlan::TableScan(table_scan, _) if table_scan.filters.is_empty() => { if let Some(sub_plan) = table_scan.source.get_logical_plan() { let sub_plan = sub_plan.into_owned(); let projection_exprs = @@ -71,7 +71,7 @@ fn analyze_internal(plan: LogicalPlan) -> Result> { .build() .map(Transformed::yes) } else { - Ok(Transformed::no(LogicalPlan::TableScan(table_scan))) + Ok(Transformed::no(LogicalPlan::table_scan(table_scan))) } } _ => Ok(Transformed::no(plan)), @@ -88,12 +88,12 @@ fn generate_projection_expr( let mut exprs = vec![]; if let Some(projection) = projection { for i in projection { - exprs.push(Expr::Column(Column::from( + exprs.push(Expr::column(Column::from( sub_plan.schema().qualified_field(*i), ))); } } else { - exprs.push(Expr::Wildcard(Wildcard { + exprs.push(Expr::wildcard(Wildcard { qualifier: None, options: WildcardOptions::default(), })); diff --git a/datafusion/optimizer/src/analyzer/mod.rs b/datafusion/optimizer/src/analyzer/mod.rs index a9fd4900b2f4a..afad2fae2ca81 100644 --- a/datafusion/optimizer/src/analyzer/mod.rs +++ b/datafusion/optimizer/src/analyzer/mod.rs @@ -181,9 +181,9 @@ fn check_plan(plan: &LogicalPlan) -> Result<()> { // recursively look for subqueries expr.apply(|expr| { match expr { - Expr::Exists(Exists { subquery, .. }) - | Expr::InSubquery(InSubquery { subquery, .. }) - | Expr::ScalarSubquery(subquery) => { + Expr::Exists(Exists { subquery, .. }, _) + | Expr::InSubquery(InSubquery { subquery, .. }, _) + | Expr::ScalarSubquery(subquery, _) => { check_subquery_expr(plan, &subquery.subquery, expr)?; } _ => {} diff --git a/datafusion/optimizer/src/analyzer/resolve_grouping_function.rs b/datafusion/optimizer/src/analyzer/resolve_grouping_function.rs index 16ebb8cd3972f..bea97a2463590 100644 --- a/datafusion/optimizer/src/analyzer/resolve_grouping_function.rs +++ b/datafusion/optimizer/src/analyzer/resolve_grouping_function.rs @@ -30,7 +30,7 @@ use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::{ internal_datafusion_err, plan_err, Column, DFSchemaRef, Result, ScalarValue, }; -use datafusion_expr::expr::{AggregateFunction, Alias}; +use datafusion_expr::expr::AggregateFunction; use datafusion_expr::logical_plan::LogicalPlan; use datafusion_expr::utils::grouping_set_to_exprlist; use datafusion_expr::{ @@ -79,7 +79,7 @@ fn replace_grouping_exprs( aggr_expr: Vec, ) -> Result { // Create HashMap from Expr to index in the grouping_id bitmap - let is_grouping_set = matches!(group_expr.as_slice(), [Expr::GroupingSet(_)]); + let is_grouping_set = matches!(group_expr.as_slice(), [Expr::GroupingSet(_, _)]); let group_expr_to_bitmap_index = group_expr_to_bitmap_index(&group_expr)?; let columns = schema.columns(); let mut new_agg_expr = Vec::new(); @@ -90,36 +90,33 @@ fn replace_grouping_exprs( columns .iter() .take(group_expr_len) - .map(|column| Expr::Column(column.clone())), + .map(|column| Expr::column(column.clone())), ); for (expr, column) in aggr_expr .into_iter() .zip(columns.into_iter().skip(group_expr_len + grouping_id_len)) { match expr { - Expr::AggregateFunction(ref function) if is_grouping_function(&expr) => { + Expr::AggregateFunction(ref function, _) if is_grouping_function(&expr) => { let grouping_expr = grouping_function_on_id( function, &group_expr_to_bitmap_index, is_grouping_set, )?; - projection_exprs.push(Expr::Alias(Alias::new( - grouping_expr, - column.relation, - column.name, - ))); + projection_exprs + .push(grouping_expr.alias_qualified(column.relation, column.name)); } _ => { - projection_exprs.push(Expr::Column(column)); + projection_exprs.push(Expr::column(column)); new_agg_expr.push(expr); } } } // Recreate aggregate without grouping functions let new_aggregate = - LogicalPlan::Aggregate(Aggregate::try_new(input, group_expr, new_agg_expr)?); + LogicalPlan::aggregate(Aggregate::try_new(input, group_expr, new_agg_expr)?); // Create projection with grouping functions calculations - let projection = LogicalPlan::Projection(Projection::try_new( + let projection = LogicalPlan::projection(Projection::try_new( projection_exprs, new_aggregate.into(), )?); @@ -132,13 +129,16 @@ fn analyze_internal(plan: LogicalPlan) -> Result> { plan.map_subqueries(|plan| plan.transform_up(analyze_internal))?; let transformed_plan = transformed_plan.transform_data(|plan| match plan { - LogicalPlan::Aggregate(Aggregate { - input, - group_expr, - aggr_expr, - schema, - .. - }) if contains_grouping_function(&aggr_expr) => Ok(Transformed::yes( + LogicalPlan::Aggregate( + Aggregate { + input, + group_expr, + aggr_expr, + schema, + .. + }, + _, + ) if contains_grouping_function(&aggr_expr) => Ok(Transformed::yes( replace_grouping_exprs(input, schema, group_expr, aggr_expr)?, )), _ => Ok(Transformed::no(plan)), @@ -150,7 +150,7 @@ fn analyze_internal(plan: LogicalPlan) -> Result> { fn is_grouping_function(expr: &Expr) -> bool { // TODO: Do something better than name here should grouping be a built // in expression? - matches!(expr, Expr::AggregateFunction(AggregateFunction { ref func, .. }) if func.name() == "grouping") + matches!(expr, Expr::AggregateFunction(AggregateFunction { ref func, .. }, _) if func.name() == "grouping") } fn contains_grouping_function(exprs: &[Expr]) -> bool { @@ -188,23 +188,23 @@ fn grouping_function_on_id( // Postgres allows grouping function for group by without grouping sets, the result is then // always 0 if !is_grouping_set { - return Ok(Expr::Literal(ScalarValue::from(0i32))); + return Ok(Expr::literal(ScalarValue::from(0i32))); } let group_by_expr_count = group_by_expr.len(); let literal = |value: usize| { if group_by_expr_count < 8 { - Expr::Literal(ScalarValue::from(value as u8)) + Expr::literal(ScalarValue::from(value as u8)) } else if group_by_expr_count < 16 { - Expr::Literal(ScalarValue::from(value as u16)) + Expr::literal(ScalarValue::from(value as u16)) } else if group_by_expr_count < 32 { - Expr::Literal(ScalarValue::from(value as u32)) + Expr::literal(ScalarValue::from(value as u32)) } else { - Expr::Literal(ScalarValue::from(value as u64)) + Expr::literal(ScalarValue::from(value as u64)) } }; - let grouping_id_column = Expr::Column(Column::from(Aggregate::INTERNAL_GROUPING_ID)); + let grouping_id_column = Expr::column(Column::from(Aggregate::INTERNAL_GROUPING_ID)); // The grouping call is exactly our internal grouping id if args.len() == group_by_expr_count && args diff --git a/datafusion/optimizer/src/analyzer/subquery.rs b/datafusion/optimizer/src/analyzer/subquery.rs index fee06eeb9f75f..2ff148e559872 100644 --- a/datafusion/optimizer/src/analyzer/subquery.rs +++ b/datafusion/optimizer/src/analyzer/subquery.rs @@ -38,7 +38,7 @@ pub fn check_subquery_expr( expr: &Expr, ) -> Result<()> { check_plan(inner_plan)?; - if let Expr::ScalarSubquery(subquery) = expr { + if let Expr::ScalarSubquery(subquery, _) = expr { // Scalar subquery should only return one column if subquery.subquery.schema().fields().len() > 1 { return plan_err!( @@ -50,13 +50,13 @@ pub fn check_subquery_expr( // Correlated scalar subquery must be aggregated to return at most one row if !subquery.outer_ref_columns.is_empty() { match strip_inner_query(inner_plan) { - LogicalPlan::Aggregate(agg) => { + LogicalPlan::Aggregate(agg, _) => { check_aggregation_in_scalar_subquery(inner_plan, agg) } - LogicalPlan::Filter(Filter { input, .. }) - if matches!(input.as_ref(), LogicalPlan::Aggregate(_)) => + LogicalPlan::Filter(Filter { input, .. }, _) + if matches!(input.as_ref(), LogicalPlan::Aggregate(_, _)) => { - if let LogicalPlan::Aggregate(agg) = input.as_ref() { + if let LogicalPlan::Aggregate(agg, _) = input.as_ref() { check_aggregation_in_scalar_subquery(inner_plan, agg) } else { Ok(()) @@ -77,9 +77,9 @@ pub fn check_subquery_expr( } }?; match outer_plan { - LogicalPlan::Projection(_) - | LogicalPlan::Filter(_) => Ok(()), - LogicalPlan::Aggregate(Aggregate {group_expr, aggr_expr,..}) => { + LogicalPlan::Projection(_, _) + | LogicalPlan::Filter(_, _) => Ok(()), + LogicalPlan::Aggregate(Aggregate {group_expr, aggr_expr,..}, _) => { if group_expr.contains(expr) && !aggr_expr.contains(expr) { // TODO revisit this validation logic plan_err!( @@ -96,7 +96,7 @@ pub fn check_subquery_expr( } check_correlations_in_subquery(inner_plan) } else { - if let Expr::InSubquery(subquery) = expr { + if let Expr::InSubquery(subquery, _) = expr { // InSubquery should only return one column if subquery.subquery.subquery.schema().fields().len() > 1 { return plan_err!( @@ -107,11 +107,11 @@ pub fn check_subquery_expr( } } match outer_plan { - LogicalPlan::Projection(_) - | LogicalPlan::Filter(_) - | LogicalPlan::Window(_) - | LogicalPlan::Aggregate(_) - | LogicalPlan::Join(_) => Ok(()), + LogicalPlan::Projection(_, _) + | LogicalPlan::Filter(_, _) + | LogicalPlan::Window(_, _) + | LogicalPlan::Aggregate(_, _) + | LogicalPlan::Join(_, _) => Ok(()), _ => plan_err!( "In/Exist subquery can only be used in \ Projection, Filter, Window functions, Aggregate and Join plan nodes, \ @@ -136,17 +136,17 @@ fn check_inner_plan(inner_plan: &LogicalPlan, can_contain_outer_ref: bool) -> Re } // We want to support as many operators as possible inside the correlated subquery match inner_plan { - LogicalPlan::Aggregate(_) => { + LogicalPlan::Aggregate(_, _) => { inner_plan.apply_children(|plan| { check_inner_plan(plan, can_contain_outer_ref)?; Ok(TreeNodeRecursion::Continue) })?; Ok(()) } - LogicalPlan::Filter(Filter { input, .. }) => { + LogicalPlan::Filter(Filter { input, .. }, _) => { check_inner_plan(input, can_contain_outer_ref) } - LogicalPlan::Window(window) => { + LogicalPlan::Window(window, _) => { check_mixed_out_refer_in_window(window)?; inner_plan.apply_children(|plan| { check_inner_plan(plan, can_contain_outer_ref)?; @@ -154,29 +154,32 @@ fn check_inner_plan(inner_plan: &LogicalPlan, can_contain_outer_ref: bool) -> Re })?; Ok(()) } - LogicalPlan::Projection(_) - | LogicalPlan::Distinct(_) - | LogicalPlan::Sort(_) - | LogicalPlan::Union(_) - | LogicalPlan::TableScan(_) - | LogicalPlan::EmptyRelation(_) - | LogicalPlan::Limit(_) - | LogicalPlan::Values(_) - | LogicalPlan::Subquery(_) - | LogicalPlan::SubqueryAlias(_) - | LogicalPlan::Unnest(_) => { + LogicalPlan::Projection(_, _) + | LogicalPlan::Distinct(_, _) + | LogicalPlan::Sort(_, _) + | LogicalPlan::Union(_, _) + | LogicalPlan::TableScan(_, _) + | LogicalPlan::EmptyRelation(_, _) + | LogicalPlan::Limit(_, _) + | LogicalPlan::Values(_, _) + | LogicalPlan::Subquery(_, _) + | LogicalPlan::SubqueryAlias(_, _) + | LogicalPlan::Unnest(_, _) => { inner_plan.apply_children(|plan| { check_inner_plan(plan, can_contain_outer_ref)?; Ok(TreeNodeRecursion::Continue) })?; Ok(()) } - LogicalPlan::Join(Join { - left, - right, - join_type, - .. - }) => match join_type { + LogicalPlan::Join( + Join { + left, + right, + join_type, + .. + }, + _, + ) => match join_type { JoinType::Inner => { inner_plan.apply_children(|plan| { check_inner_plan(plan, can_contain_outer_ref)?; @@ -203,7 +206,7 @@ fn check_inner_plan(inner_plan: &LogicalPlan, can_contain_outer_ref: bool) -> Re Ok(()) } }, - LogicalPlan::Extension(_) => Ok(()), + LogicalPlan::Extension(_, _) => Ok(()), _ => plan_err!("Unsupported operator in the subquery plan."), } } @@ -241,10 +244,10 @@ fn check_aggregation_in_scalar_subquery( fn strip_inner_query(inner_plan: &LogicalPlan) -> &LogicalPlan { match inner_plan { - LogicalPlan::Projection(projection) => { + LogicalPlan::Projection(projection, _) => { strip_inner_query(projection.input.as_ref()) } - LogicalPlan::SubqueryAlias(alias) => strip_inner_query(alias.input.as_ref()), + LogicalPlan::SubqueryAlias(alias, _) => strip_inner_query(alias.input.as_ref()), other => other, } } @@ -252,7 +255,7 @@ fn strip_inner_query(inner_plan: &LogicalPlan) -> &LogicalPlan { fn get_correlated_expressions(inner_plan: &LogicalPlan) -> Result> { let mut exprs = vec![]; inner_plan.apply_with_subqueries(|plan| { - if let LogicalPlan::Filter(Filter { predicate, .. }) = plan { + if let LogicalPlan::Filter(Filter { predicate, .. }, _) = plan { let (correlated, _): (Vec<_>, Vec<_>) = split_conjunction(predicate) .into_iter() .partition(|e| e.contains_outer()); @@ -340,7 +343,7 @@ mod test { #[test] fn wont_fail_extension_plan() { - let plan = LogicalPlan::Extension(Extension { + let plan = LogicalPlan::extension(Extension { node: Arc::new(MockUserDefinedLogicalPlan { empty_schema: DFSchemaRef::new(DFSchema::empty()), }), diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index b56c2dc604a9b..26f05efca7494 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -114,7 +114,7 @@ fn analyze_internal( // resolution only, so order does not matter here let mut schema = merge_schema(&plan.inputs()); - if let LogicalPlan::TableScan(ts) = &plan { + if let LogicalPlan::TableScan(ts, _) = &plan { let source_schema = DFSchema::try_from_qualified_schema( ts.table_name.clone(), &ts.source.schema(), @@ -128,9 +128,9 @@ fn analyze_internal( schema.merge(external_schema); // Coerce filter predicates to boolean (handles `WHERE NULL`) - let plan = if let LogicalPlan::Filter(mut filter) = plan { + let plan = if let LogicalPlan::Filter(mut filter, _) = plan { filter.predicate = filter.predicate.cast_to(&DataType::Boolean, &schema)?; - LogicalPlan::Filter(filter) + LogicalPlan::filter(filter) } else { plan }; @@ -168,9 +168,9 @@ impl<'a> TypeCoercionRewriter<'a> { /// for type-coercion approach. pub fn coerce_plan(&mut self, plan: LogicalPlan) -> Result { match plan { - LogicalPlan::Join(join) => self.coerce_join(join), - LogicalPlan::Union(union) => Self::coerce_union(union), - LogicalPlan::Limit(limit) => Self::coerce_limit(limit), + LogicalPlan::Join(join, _) => self.coerce_join(join), + LogicalPlan::Union(union, _) => Self::coerce_union(union), + LogicalPlan::Limit(limit, _) => Self::coerce_limit(limit), _ => Ok(plan), } } @@ -201,7 +201,7 @@ impl<'a> TypeCoercionRewriter<'a> { .map(|expr| self.coerce_join_filter(expr)) .transpose()?; - Ok(LogicalPlan::Join(join)) + Ok(LogicalPlan::join(join)) } /// Coerce the union’s inputs to a common schema compatible with all inputs. @@ -215,7 +215,7 @@ impl<'a> TypeCoercionRewriter<'a> { let plan = coerce_plan_expr_for_schema(Arc::unwrap_or_clone(p), &union_schema)?; match plan { - LogicalPlan::Projection(Projection { expr, input, .. }) => { + LogicalPlan::Projection(Projection { expr, input, .. }, _) => { Ok(Arc::new(project_with_column_index( expr, input, @@ -226,7 +226,7 @@ impl<'a> TypeCoercionRewriter<'a> { } }) .collect::>>()?; - Ok(LogicalPlan::Union(Union { + Ok(LogicalPlan::union(Union { inputs: new_inputs, schema: union_schema, })) @@ -256,7 +256,7 @@ impl<'a> TypeCoercionRewriter<'a> { .skip .map(|expr| coerce_limit_expr(*expr, &empty_schema, "OFFSET")) .transpose()?; - Ok(LogicalPlan::Limit(Limit { + Ok(LogicalPlan::limit(Limit { input: limit.input, fetch: new_fetch.map(Box::new), skip: new_skip.map(Box::new), @@ -295,27 +295,30 @@ impl<'a> TreeNodeRewriter for TypeCoercionRewriter<'a> { fn f_up(&mut self, expr: Expr) -> Result> { match expr { - Expr::Unnest(_) => not_impl_err!( + Expr::Unnest(_, _) => not_impl_err!( "Unnest should be rewritten to LogicalPlan::Unnest before type coercion" ), - Expr::ScalarSubquery(Subquery { - subquery, - outer_ref_columns, - }) => { + Expr::ScalarSubquery( + Subquery { + subquery, + outer_ref_columns, + }, + _, + ) => { let new_plan = analyze_internal(self.schema, Arc::unwrap_or_clone(subquery))?.data; - Ok(Transformed::yes(Expr::ScalarSubquery(Subquery { + Ok(Transformed::yes(Expr::scalar_subquery(Subquery { subquery: Arc::new(new_plan), outer_ref_columns, }))) } - Expr::Exists(Exists { subquery, negated }) => { + Expr::Exists(Exists { subquery, negated }, _) => { let new_plan = analyze_internal( self.schema, Arc::unwrap_or_clone(subquery.subquery), )? .data; - Ok(Transformed::yes(Expr::Exists(Exists { + Ok(Transformed::yes(Expr::exists(Exists { subquery: Subquery { subquery: Arc::new(new_plan), outer_ref_columns: subquery.outer_ref_columns, @@ -323,11 +326,14 @@ impl<'a> TreeNodeRewriter for TypeCoercionRewriter<'a> { negated, }))) } - Expr::InSubquery(InSubquery { - expr, - subquery, - negated, - }) => { + Expr::InSubquery( + InSubquery { + expr, + subquery, + negated, + }, + _, + ) => { let new_plan = analyze_internal( self.schema, Arc::unwrap_or_clone(subquery.subquery), @@ -343,41 +349,44 @@ impl<'a> TreeNodeRewriter for TypeCoercionRewriter<'a> { subquery: Arc::new(new_plan), outer_ref_columns: subquery.outer_ref_columns, }; - Ok(Transformed::yes(Expr::InSubquery(InSubquery::new( + Ok(Transformed::yes(Expr::in_subquery(InSubquery::new( Box::new(expr.cast_to(&common_type, self.schema)?), cast_subquery(new_subquery, &common_type)?, negated, )))) } - Expr::Not(expr) => Ok(Transformed::yes(not(get_casted_expr_for_bool_op( + Expr::Not(expr, _) => Ok(Transformed::yes(not(get_casted_expr_for_bool_op( *expr, self.schema, )?))), - Expr::IsTrue(expr) => Ok(Transformed::yes(is_true( + Expr::IsTrue(expr, _) => Ok(Transformed::yes(is_true( get_casted_expr_for_bool_op(*expr, self.schema)?, ))), - Expr::IsNotTrue(expr) => Ok(Transformed::yes(is_not_true( + Expr::IsNotTrue(expr, _) => Ok(Transformed::yes(is_not_true( get_casted_expr_for_bool_op(*expr, self.schema)?, ))), - Expr::IsFalse(expr) => Ok(Transformed::yes(is_false( + Expr::IsFalse(expr, _) => Ok(Transformed::yes(is_false( get_casted_expr_for_bool_op(*expr, self.schema)?, ))), - Expr::IsNotFalse(expr) => Ok(Transformed::yes(is_not_false( + Expr::IsNotFalse(expr, _) => Ok(Transformed::yes(is_not_false( get_casted_expr_for_bool_op(*expr, self.schema)?, ))), - Expr::IsUnknown(expr) => Ok(Transformed::yes(is_unknown( + Expr::IsUnknown(expr, _) => Ok(Transformed::yes(is_unknown( get_casted_expr_for_bool_op(*expr, self.schema)?, ))), - Expr::IsNotUnknown(expr) => Ok(Transformed::yes(is_not_unknown( + Expr::IsNotUnknown(expr, _) => Ok(Transformed::yes(is_not_unknown( get_casted_expr_for_bool_op(*expr, self.schema)?, ))), - Expr::Like(Like { - negated, - expr, - pattern, - escape_char, - case_insensitive, - }) => { + Expr::Like( + Like { + negated, + expr, + pattern, + escape_char, + case_insensitive, + }, + _, + ) => { let left_type = expr.get_type(self.schema)?; let right_type = pattern.get_type(self.schema)?; let coerced_type = like_coercion(&left_type, &right_type).ok_or_else(|| { @@ -395,7 +404,7 @@ impl<'a> TreeNodeRewriter for TypeCoercionRewriter<'a> { _ => Box::new(expr.cast_to(&coerced_type, self.schema)?), }; let pattern = Box::new(pattern.cast_to(&coerced_type, self.schema)?); - Ok(Transformed::yes(Expr::Like(Like::new( + Ok(Transformed::yes(Expr::_like(Like::new( negated, expr, pattern, @@ -403,20 +412,23 @@ impl<'a> TreeNodeRewriter for TypeCoercionRewriter<'a> { case_insensitive, )))) } - Expr::BinaryExpr(BinaryExpr { left, op, right }) => { + Expr::BinaryExpr(BinaryExpr { left, op, right }, _) => { let (left, right) = self.coerce_binary_op(*left, op, *right)?; - Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr::new( + Ok(Transformed::yes(Expr::binary_expr(BinaryExpr::new( Box::new(left), op, Box::new(right), )))) } - Expr::Between(Between { - expr, - negated, - low, - high, - }) => { + Expr::Between( + Between { + expr, + negated, + low, + high, + }, + _, + ) => { let expr_type = expr.get_type(self.schema)?; let low_type = low.get_type(self.schema)?; let low_coerced_type = comparison_coercion(&expr_type, &low_type) @@ -439,18 +451,21 @@ impl<'a> TreeNodeRewriter for TypeCoercionRewriter<'a> { "Failed to coerce types {expr_type} and {high_type} in BETWEEN expression" )) })?; - Ok(Transformed::yes(Expr::Between(Between::new( + Ok(Transformed::yes(Expr::_between(Between::new( Box::new(expr.cast_to(&coercion_type, self.schema)?), negated, Box::new(low.cast_to(&coercion_type, self.schema)?), Box::new(high.cast_to(&coercion_type, self.schema)?), )))) } - Expr::InList(InList { - expr, - list, - negated, - }) => { + Expr::InList( + InList { + expr, + list, + negated, + }, + _, + ) => { let expr_data_type = expr.get_type(self.schema)?; let list_data_types = list .iter() @@ -471,7 +486,7 @@ impl<'a> TreeNodeRewriter for TypeCoercionRewriter<'a> { list_expr.cast_to(&coerced_type, self.schema) }) .collect::>>()?; - Ok(Transformed::yes(Expr::InList(InList ::new( + Ok(Transformed::yes(Expr::_in_list(InList ::new( Box::new(cast_expr), cast_list_expr, negated, @@ -479,34 +494,37 @@ impl<'a> TreeNodeRewriter for TypeCoercionRewriter<'a> { } } } - Expr::Case(case) => { + Expr::Case(case, _) => { let case = coerce_case_expression(case, self.schema)?; - Ok(Transformed::yes(Expr::Case(case))) + Ok(Transformed::yes(Expr::case(case))) } - Expr::ScalarFunction(ScalarFunction { func, args }) => { + Expr::ScalarFunction(ScalarFunction { func, args }, _) => { let new_expr = coerce_arguments_for_signature_with_scalar_udf( args, self.schema, &func, )?; - Ok(Transformed::yes(Expr::ScalarFunction( + Ok(Transformed::yes(Expr::scalar_function( ScalarFunction::new_udf(func, new_expr), ))) } - Expr::AggregateFunction(expr::AggregateFunction { - func, - args, - distinct, - filter, - order_by, - null_treatment, - }) => { + Expr::AggregateFunction( + expr::AggregateFunction { + func, + args, + distinct, + filter, + order_by, + null_treatment, + }, + _, + ) => { let new_expr = coerce_arguments_for_signature_with_aggregate_udf( args, self.schema, &func, )?; - Ok(Transformed::yes(Expr::AggregateFunction( + Ok(Transformed::yes(Expr::aggregate_function( expr::AggregateFunction::new_udf( func, new_expr, @@ -517,14 +535,17 @@ impl<'a> TreeNodeRewriter for TypeCoercionRewriter<'a> { ), ))) } - Expr::WindowFunction(WindowFunction { - fun, - args, - partition_by, - order_by, - window_frame, - null_treatment, - }) => { + Expr::WindowFunction( + WindowFunction { + fun, + args, + partition_by, + order_by, + window_frame, + null_treatment, + }, + _, + ) => { let window_frame = coerce_window_frame(window_frame, self.schema, &order_by)?; @@ -540,7 +561,7 @@ impl<'a> TreeNodeRewriter for TypeCoercionRewriter<'a> { }; Ok(Transformed::yes( - Expr::WindowFunction(WindowFunction::new(fun, args)) + Expr::window_function(WindowFunction::new(fun, args)) .partition_by(partition_by) .order_by(order_by) .window_frame(window_frame) @@ -548,20 +569,20 @@ impl<'a> TreeNodeRewriter for TypeCoercionRewriter<'a> { .build()?, )) } - Expr::Alias(_) - | Expr::Column(_) - | Expr::ScalarVariable(_, _) - | Expr::Literal(_) - | Expr::SimilarTo(_) - | Expr::IsNotNull(_) - | Expr::IsNull(_) - | Expr::Negative(_) - | Expr::Cast(_) - | Expr::TryCast(_) + Expr::Alias(_, _) + | Expr::Column(_, _) + | Expr::ScalarVariable(_, _, _) + | Expr::Literal(_, _) + | Expr::SimilarTo(_, _) + | Expr::IsNotNull(_, _) + | Expr::IsNull(_, _) + | Expr::Negative(_, _) + | Expr::Cast(_, _) + | Expr::TryCast(_, _) | Expr::Wildcard { .. } - | Expr::GroupingSet(_) - | Expr::Placeholder(_) - | Expr::OuterReferenceColumn(_, _) => Ok(Transformed::no(expr)), + | Expr::GroupingSet(_, _) + | Expr::Placeholder(_, _) + | Expr::OuterReferenceColumn(_, _, _) => Ok(Transformed::no(expr)), } } } @@ -993,20 +1014,23 @@ fn project_with_column_index( .into_iter() .enumerate() .map(|(i, e)| match e { - Expr::Alias(Alias { ref name, .. }) if name != schema.field(i).name() => { + Expr::Alias(Alias { ref name, .. }, _) if name != schema.field(i).name() => { e.unalias().alias(schema.field(i).name()) } - Expr::Column(Column { - relation: _, - ref name, - }) if name != schema.field(i).name() => e.alias(schema.field(i).name()), + Expr::Column( + Column { + relation: _, + ref name, + }, + _, + ) if name != schema.field(i).name() => e.alias(schema.field(i).name()), Expr::Alias { .. } | Expr::Column { .. } => e, _ => e.alias(schema.field(i).name()), }) .collect::>(); Projection::try_new_with_schema(alias_expr, input, schema) - .map(LogicalPlan::Projection) + .map(LogicalPlan::projection) } #[cfg(test)] @@ -1037,14 +1061,14 @@ mod test { use crate::test::{assert_analyzed_plan_eq, assert_analyzed_plan_with_config_eq}; fn empty() -> Arc { - Arc::new(LogicalPlan::EmptyRelation(EmptyRelation { + Arc::new(LogicalPlan::empty_relation(EmptyRelation { produce_one_row: false, schema: Arc::new(DFSchema::empty()), })) } fn empty_with_type(data_type: DataType) -> Arc { - Arc::new(LogicalPlan::EmptyRelation(EmptyRelation { + Arc::new(LogicalPlan::empty_relation(EmptyRelation { produce_one_row: false, schema: Arc::new( DFSchema::from_unqualified_fields( @@ -1060,7 +1084,7 @@ mod test { fn simple_case() -> Result<()> { let expr = col("a").lt(lit(2_u32)); let empty = empty_with_type(DataType::Float64); - let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); + let plan = LogicalPlan::projection(Projection::try_new(vec![expr], empty)?); let expected = "Projection: a < CAST(UInt32(2) AS Float64)\n EmptyRelation"; assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected) } @@ -1092,7 +1116,7 @@ mod test { // scenario: outermost utf8view projection let expr = col("a"); let empty = empty_with_type(DataType::Utf8View); - let plan = LogicalPlan::Projection(Projection::try_new( + let plan = LogicalPlan::projection(Projection::try_new( vec![expr.clone()], Arc::clone(&empty), )?); @@ -1106,7 +1130,7 @@ mod test { // Plan B // scenario: outermost bool projection let bool_expr = col("a").lt(lit("foo")); - let bool_plan = LogicalPlan::Projection(Projection::try_new( + let bool_plan = LogicalPlan::projection(Projection::try_new( vec![bool_expr], Arc::clone(&empty), )?); @@ -1121,7 +1145,7 @@ mod test { // Plan C // scenario: with a non-projection root logical plan node let sort_expr = expr.sort(true, true); - let sort_plan = LogicalPlan::Sort(Sort { + let sort_plan = LogicalPlan::sort(Sort { expr: vec![sort_expr], input: Arc::new(plan), fetch: None, @@ -1136,7 +1160,7 @@ mod test { // Plan D // scenario: two layers of projections with view types - let plan = LogicalPlan::Projection(Projection::try_new( + let plan = LogicalPlan::projection(Projection::try_new( vec![col("a")], Arc::new(sort_plan), )?); @@ -1156,7 +1180,7 @@ mod test { // scenario: outermost binaryview projection let expr = col("a"); let empty = empty_with_type(DataType::BinaryView); - let plan = LogicalPlan::Projection(Projection::try_new( + let plan = LogicalPlan::projection(Projection::try_new( vec![expr.clone()], Arc::clone(&empty), )?); @@ -1170,7 +1194,7 @@ mod test { // Plan B // scenario: outermost bool projection let bool_expr = col("a").lt(lit(vec![8, 1, 8, 1])); - let bool_plan = LogicalPlan::Projection(Projection::try_new( + let bool_plan = LogicalPlan::projection(Projection::try_new( vec![bool_expr], Arc::clone(&empty), )?); @@ -1185,7 +1209,7 @@ mod test { // Plan C // scenario: with a non-projection root logical plan node let sort_expr = expr.sort(true, true); - let sort_plan = LogicalPlan::Sort(Sort { + let sort_plan = LogicalPlan::sort(Sort { expr: vec![sort_expr], input: Arc::new(plan), fetch: None, @@ -1200,7 +1224,7 @@ mod test { // Plan D // scenario: two layers of projections with view types - let plan = LogicalPlan::Projection(Projection::try_new( + let plan = LogicalPlan::projection(Projection::try_new( vec![col("a")], Arc::new(sort_plan), )?); @@ -1219,7 +1243,7 @@ mod test { let expr = col("a").lt(lit(2_u32)); let empty = empty_with_type(DataType::Float64); - let plan = LogicalPlan::Projection(Projection::try_new( + let plan = LogicalPlan::projection(Projection::try_new( vec![expr.clone().or(expr)], empty, )?); @@ -1263,7 +1287,7 @@ mod test { signature: Signature::uniform(1, vec![DataType::Float32], Volatility::Stable), }) .call(vec![lit(123_i32)]); - let plan = LogicalPlan::Projection(Projection::try_new(vec![udf], empty)?); + let plan = LogicalPlan::projection(Projection::try_new(vec![udf], empty)?); let expected = "Projection: TestScalarUDF(CAST(Int32(123) AS Float32))\n EmptyRelation"; assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected) @@ -1291,8 +1315,8 @@ mod test { signature: Signature::uniform(1, vec![DataType::Float32], Volatility::Stable), }); let scalar_function_expr = - Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![lit_expr])); - let plan = LogicalPlan::Projection(Projection::try_new( + Expr::scalar_function(ScalarFunction::new_udf(Arc::new(fun), vec![lit_expr])); + let plan = LogicalPlan::projection(Projection::try_new( vec![scalar_function_expr], empty, )?); @@ -1312,7 +1336,7 @@ mod test { Arc::new(|_| Ok(Box::::default())), Arc::new(vec![DataType::UInt64, DataType::Float64]), ); - let udaf = Expr::AggregateFunction(expr::AggregateFunction::new_udf( + let udaf = Expr::aggregate_function(expr::AggregateFunction::new_udf( Arc::new(my_avg), vec![lit(10i64)], false, @@ -1320,7 +1344,7 @@ mod test { None, None, )); - let plan = LogicalPlan::Projection(Projection::try_new(vec![udaf], empty)?); + let plan = LogicalPlan::projection(Projection::try_new(vec![udaf], empty)?); let expected = "Projection: MY_AVG(CAST(Int64(10) AS Float64))\n EmptyRelation"; assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected) } @@ -1341,7 +1365,7 @@ mod test { Field::new("avg", DataType::Float64, true), ], )); - let udaf = Expr::AggregateFunction(expr::AggregateFunction::new_udf( + let udaf = Expr::aggregate_function(expr::AggregateFunction::new_udf( Arc::new(my_avg), vec![lit("10")], false, @@ -1360,7 +1384,7 @@ mod test { #[test] fn agg_function_case() -> Result<()> { let empty = empty(); - let agg_expr = Expr::AggregateFunction(expr::AggregateFunction::new_udf( + let agg_expr = Expr::aggregate_function(expr::AggregateFunction::new_udf( avg_udaf(), vec![lit(12f64)], false, @@ -1368,12 +1392,12 @@ mod test { None, None, )); - let plan = LogicalPlan::Projection(Projection::try_new(vec![agg_expr], empty)?); + let plan = LogicalPlan::projection(Projection::try_new(vec![agg_expr], empty)?); let expected = "Projection: avg(Float64(12))\n EmptyRelation"; assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; let empty = empty_with_type(DataType::Int32); - let agg_expr = Expr::AggregateFunction(expr::AggregateFunction::new_udf( + let agg_expr = Expr::aggregate_function(expr::AggregateFunction::new_udf( avg_udaf(), vec![cast(col("a"), DataType::Float64)], false, @@ -1381,7 +1405,7 @@ mod test { None, None, )); - let plan = LogicalPlan::Projection(Projection::try_new(vec![agg_expr], empty)?); + let plan = LogicalPlan::projection(Projection::try_new(vec![agg_expr], empty)?); let expected = "Projection: avg(CAST(a AS Float64))\n EmptyRelation"; assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; Ok(()) @@ -1390,7 +1414,7 @@ mod test { #[test] fn agg_function_invalid_input_avg() -> Result<()> { let empty = empty(); - let agg_expr = Expr::AggregateFunction(expr::AggregateFunction::new_udf( + let agg_expr = Expr::aggregate_function(expr::AggregateFunction::new_udf( avg_udaf(), vec![lit("1")], false, @@ -1412,7 +1436,7 @@ mod test { let expr = cast(lit("1998-03-18"), DataType::Date32) + lit(ScalarValue::new_interval_dt(123, 456)); let empty = empty(); - let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); + let plan = LogicalPlan::projection(Projection::try_new(vec![expr], empty)?); let expected = "Projection: CAST(Utf8(\"1998-03-18\") AS Date32) + IntervalDayTime(\"IntervalDayTime { days: 123, milliseconds: 456 }\")\n EmptyRelation"; assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; @@ -1424,20 +1448,20 @@ mod test { // a in (1,4,8), a is int64 let expr = col("a").in_list(vec![lit(1_i32), lit(4_i8), lit(8_i64)], false); let empty = empty_with_type(DataType::Int64); - let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); + let plan = LogicalPlan::projection(Projection::try_new(vec![expr], empty)?); let expected = "Projection: a IN ([CAST(Int32(1) AS Int64), CAST(Int8(4) AS Int64), Int64(8)])\n EmptyRelation"; assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; // a in (1,4,8), a is decimal let expr = col("a").in_list(vec![lit(1_i32), lit(4_i8), lit(8_i64)], false); - let empty = Arc::new(LogicalPlan::EmptyRelation(EmptyRelation { + let empty = Arc::new(LogicalPlan::empty_relation(EmptyRelation { produce_one_row: false, schema: Arc::new(DFSchema::from_unqualified_fields( vec![Field::new("a", DataType::Decimal128(12, 4), true)].into(), std::collections::HashMap::new(), )?), })); - let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); + let plan = LogicalPlan::projection(Projection::try_new(vec![expr], empty)?); let expected = "Projection: CAST(a AS Decimal128(24, 4)) IN ([CAST(Int32(1) AS Decimal128(24, 4)), CAST(Int8(4) AS Decimal128(24, 4)), CAST(Int64(8) AS Decimal128(24, 4))])\n EmptyRelation"; assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected) } @@ -1451,7 +1475,7 @@ mod test { + lit(ScalarValue::new_interval_ym(0, 1)), ); let empty = empty_with_type(Utf8); - let plan = LogicalPlan::Filter(Filter::try_new(expr, empty)?); + let plan = LogicalPlan::filter(Filter::try_new(expr, empty)?); let expected = "Filter: a BETWEEN Utf8(\"2002-05-08\") AND CAST(CAST(Utf8(\"2002-05-08\") AS Date32) + IntervalYearMonth(\"1\") AS Utf8)\ \n EmptyRelation"; @@ -1467,7 +1491,7 @@ mod test { lit("2002-12-08"), ); let empty = empty_with_type(Utf8); - let plan = LogicalPlan::Filter(Filter::try_new(expr, empty)?); + let plan = LogicalPlan::filter(Filter::try_new(expr, empty)?); // TODO: we should cast col(a). let expected = "Filter: CAST(a AS Date32) BETWEEN CAST(Utf8(\"2002-05-08\") AS Date32) + IntervalYearMonth(\"1\") AND CAST(Utf8(\"2002-12-08\") AS Date32)\ @@ -1481,12 +1505,12 @@ mod test { let expr = col("a").is_true(); let empty = empty_with_type(DataType::Boolean); let plan = - LogicalPlan::Projection(Projection::try_new(vec![expr.clone()], empty)?); + LogicalPlan::projection(Projection::try_new(vec![expr.clone()], empty)?); let expected = "Projection: a IS TRUE\n EmptyRelation"; assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; let empty = empty_with_type(DataType::Int64); - let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); + let plan = LogicalPlan::projection(Projection::try_new(vec![expr], empty)?); let ret = assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, ""); let err = ret.unwrap_err().to_string(); assert!(err.contains("Cannot infer common argument type for comparison operation Int64 IS DISTINCT FROM Boolean"), "{err}"); @@ -1494,21 +1518,21 @@ mod test { // is not true let expr = col("a").is_not_true(); let empty = empty_with_type(DataType::Boolean); - let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); + let plan = LogicalPlan::projection(Projection::try_new(vec![expr], empty)?); let expected = "Projection: a IS NOT TRUE\n EmptyRelation"; assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; // is false let expr = col("a").is_false(); let empty = empty_with_type(DataType::Boolean); - let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); + let plan = LogicalPlan::projection(Projection::try_new(vec![expr], empty)?); let expected = "Projection: a IS FALSE\n EmptyRelation"; assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; // is not false let expr = col("a").is_not_false(); let empty = empty_with_type(DataType::Boolean); - let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); + let plan = LogicalPlan::projection(Projection::try_new(vec![expr], empty)?); let expected = "Projection: a IS NOT FALSE\n EmptyRelation"; assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; @@ -1520,25 +1544,25 @@ mod test { // like : utf8 like "abc" let expr = Box::new(col("a")); let pattern = Box::new(lit(ScalarValue::new_utf8("abc"))); - let like_expr = Expr::Like(Like::new(false, expr, pattern, None, false)); + let like_expr = Expr::_like(Like::new(false, expr, pattern, None, false)); let empty = empty_with_type(Utf8); - let plan = LogicalPlan::Projection(Projection::try_new(vec![like_expr], empty)?); + let plan = LogicalPlan::projection(Projection::try_new(vec![like_expr], empty)?); let expected = "Projection: a LIKE Utf8(\"abc\")\n EmptyRelation"; assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; let expr = Box::new(col("a")); let pattern = Box::new(lit(ScalarValue::Null)); - let like_expr = Expr::Like(Like::new(false, expr, pattern, None, false)); + let like_expr = Expr::_like(Like::new(false, expr, pattern, None, false)); let empty = empty_with_type(Utf8); - let plan = LogicalPlan::Projection(Projection::try_new(vec![like_expr], empty)?); + let plan = LogicalPlan::projection(Projection::try_new(vec![like_expr], empty)?); let expected = "Projection: a LIKE CAST(NULL AS Utf8)\n EmptyRelation"; assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; let expr = Box::new(col("a")); let pattern = Box::new(lit(ScalarValue::new_utf8("abc"))); - let like_expr = Expr::Like(Like::new(false, expr, pattern, None, false)); + let like_expr = Expr::_like(Like::new(false, expr, pattern, None, false)); let empty = empty_with_type(DataType::Int64); - let plan = LogicalPlan::Projection(Projection::try_new(vec![like_expr], empty)?); + let plan = LogicalPlan::projection(Projection::try_new(vec![like_expr], empty)?); let err = assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected); assert!(err.is_err()); assert!(err.unwrap_err().to_string().contains( @@ -1548,25 +1572,25 @@ mod test { // ilike let expr = Box::new(col("a")); let pattern = Box::new(lit(ScalarValue::new_utf8("abc"))); - let ilike_expr = Expr::Like(Like::new(false, expr, pattern, None, true)); + let ilike_expr = Expr::_like(Like::new(false, expr, pattern, None, true)); let empty = empty_with_type(Utf8); - let plan = LogicalPlan::Projection(Projection::try_new(vec![ilike_expr], empty)?); + let plan = LogicalPlan::projection(Projection::try_new(vec![ilike_expr], empty)?); let expected = "Projection: a ILIKE Utf8(\"abc\")\n EmptyRelation"; assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; let expr = Box::new(col("a")); let pattern = Box::new(lit(ScalarValue::Null)); - let ilike_expr = Expr::Like(Like::new(false, expr, pattern, None, true)); + let ilike_expr = Expr::_like(Like::new(false, expr, pattern, None, true)); let empty = empty_with_type(Utf8); - let plan = LogicalPlan::Projection(Projection::try_new(vec![ilike_expr], empty)?); + let plan = LogicalPlan::projection(Projection::try_new(vec![ilike_expr], empty)?); let expected = "Projection: a ILIKE CAST(NULL AS Utf8)\n EmptyRelation"; assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; let expr = Box::new(col("a")); let pattern = Box::new(lit(ScalarValue::new_utf8("abc"))); - let ilike_expr = Expr::Like(Like::new(false, expr, pattern, None, true)); + let ilike_expr = Expr::_like(Like::new(false, expr, pattern, None, true)); let empty = empty_with_type(DataType::Int64); - let plan = LogicalPlan::Projection(Projection::try_new(vec![ilike_expr], empty)?); + let plan = LogicalPlan::projection(Projection::try_new(vec![ilike_expr], empty)?); let err = assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected); assert!(err.is_err()); assert!(err.unwrap_err().to_string().contains( @@ -1581,12 +1605,12 @@ mod test { let expr = col("a").is_unknown(); let empty = empty_with_type(DataType::Boolean); let plan = - LogicalPlan::Projection(Projection::try_new(vec![expr.clone()], empty)?); + LogicalPlan::projection(Projection::try_new(vec![expr.clone()], empty)?); let expected = "Projection: a IS UNKNOWN\n EmptyRelation"; assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; let empty = empty_with_type(Utf8); - let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); + let plan = LogicalPlan::projection(Projection::try_new(vec![expr], empty)?); let ret = assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected); let err = ret.unwrap_err().to_string(); assert!(err.contains("Cannot infer common argument type for comparison operation Utf8 IS DISTINCT FROM Boolean"), "{err}"); @@ -1594,7 +1618,7 @@ mod test { // is not unknown let expr = col("a").is_not_unknown(); let empty = empty_with_type(DataType::Boolean); - let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); + let plan = LogicalPlan::projection(Projection::try_new(vec![expr], empty)?); let expected = "Projection: a IS NOT UNKNOWN\n EmptyRelation"; assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; @@ -1612,7 +1636,7 @@ mod test { signature: Signature::variadic(vec![Utf8], Volatility::Immutable), }) .call(args.to_vec()); - let plan = LogicalPlan::Projection(Projection::try_new( + let plan = LogicalPlan::projection(Projection::try_new( vec![expr], Arc::clone(&empty), )?); @@ -1670,7 +1694,7 @@ mod test { ) .eq(cast(lit("1998-03-18"), DataType::Date32)); let empty = empty(); - let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); + let plan = LogicalPlan::projection(Projection::try_new(vec![expr], empty)?); let expected = "Projection: CAST(Utf8(\"1998-03-18\") AS Timestamp(Nanosecond, None)) = CAST(CAST(Utf8(\"1998-03-18\") AS Date32) AS Timestamp(Nanosecond, None))\n EmptyRelation"; assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; @@ -2007,7 +2031,7 @@ mod test { #[test] fn interval_plus_timestamp() -> Result<()> { // SELECT INTERVAL '1' YEAR + '2000-01-01T00:00:00'::timestamp; - let expr = Expr::BinaryExpr(BinaryExpr::new( + let expr = Expr::binary_expr(BinaryExpr::new( Box::new(lit(ScalarValue::IntervalYearMonth(Some(12)))), Operator::Plus, Box::new(cast( @@ -2016,7 +2040,7 @@ mod test { )), )); let empty = empty(); - let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); + let plan = LogicalPlan::projection(Projection::try_new(vec![expr], empty)?); let expected = "Projection: IntervalYearMonth(\"12\") + CAST(Utf8(\"2000-01-01T00:00:00\") AS Timestamp(Nanosecond, None))\n EmptyRelation"; assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; Ok(()) @@ -2024,7 +2048,7 @@ mod test { #[test] fn timestamp_subtract_timestamp() -> Result<()> { - let expr = Expr::BinaryExpr(BinaryExpr::new( + let expr = Expr::binary_expr(BinaryExpr::new( Box::new(cast( lit("1998-03-18"), DataType::Timestamp(TimeUnit::Nanosecond, None), @@ -2036,7 +2060,7 @@ mod test { )), )); let empty = empty(); - let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); + let plan = LogicalPlan::projection(Projection::try_new(vec![expr], empty)?); let expected = "Projection: CAST(Utf8(\"1998-03-18\") AS Timestamp(Nanosecond, None)) - CAST(Utf8(\"1998-03-18\") AS Timestamp(Nanosecond, None))\n EmptyRelation"; assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; @@ -2048,7 +2072,7 @@ mod test { let empty_int32 = empty_with_type(DataType::Int32); let empty_int64 = empty_with_type(DataType::Int64); - let in_subquery_expr = Expr::InSubquery(InSubquery::new( + let in_subquery_expr = Expr::in_subquery(InSubquery::new( Box::new(col("a")), Subquery { subquery: empty_int32, @@ -2056,7 +2080,7 @@ mod test { }, false, )); - let plan = LogicalPlan::Filter(Filter::try_new(in_subquery_expr, empty_int64)?); + let plan = LogicalPlan::filter(Filter::try_new(in_subquery_expr, empty_int64)?); // add cast for subquery let expected = "\ Filter: a IN ()\ @@ -2073,7 +2097,7 @@ mod test { let empty_int32 = empty_with_type(DataType::Int32); let empty_int64 = empty_with_type(DataType::Int64); - let in_subquery_expr = Expr::InSubquery(InSubquery::new( + let in_subquery_expr = Expr::in_subquery(InSubquery::new( Box::new(col("a")), Subquery { subquery: empty_int64, @@ -2081,7 +2105,7 @@ mod test { }, false, )); - let plan = LogicalPlan::Filter(Filter::try_new(in_subquery_expr, empty_int32)?); + let plan = LogicalPlan::filter(Filter::try_new(in_subquery_expr, empty_int32)?); // add cast for subquery let expected = "\ Filter: CAST(a AS Int64) IN ()\ @@ -2097,7 +2121,7 @@ mod test { let empty_inside = empty_with_type(DataType::Decimal128(10, 5)); let empty_outside = empty_with_type(DataType::Decimal128(8, 8)); - let in_subquery_expr = Expr::InSubquery(InSubquery::new( + let in_subquery_expr = Expr::in_subquery(InSubquery::new( Box::new(col("a")), Subquery { subquery: empty_inside, @@ -2105,7 +2129,7 @@ mod test { }, false, )); - let plan = LogicalPlan::Filter(Filter::try_new(in_subquery_expr, empty_outside)?); + let plan = LogicalPlan::filter(Filter::try_new(in_subquery_expr, empty_outside)?); // add cast for subquery let expected = "Filter: CAST(a AS Decimal128(13, 8)) IN ()\ \n Subquery:\ diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index 16a4fa6be38d0..ecf49ac80943d 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -88,7 +88,7 @@ impl CommonSubexprEliminate { self.try_unary_plan(expr, input, config)? .map_data(|(new_expr, new_input)| { Projection::try_new_with_schema(new_expr, Arc::new(new_input), schema) - .map(LogicalPlan::Projection) + .map(LogicalPlan::projection) }) } @@ -106,7 +106,7 @@ impl CommonSubexprEliminate { let new_sort = self .try_unary_plan(sort_expressions, input, config)? .update_data(|(new_expr, new_input)| { - LogicalPlan::Sort(Sort { + LogicalPlan::sort(Sort { expr: new_expr .into_iter() .zip(sort_params) @@ -138,7 +138,7 @@ impl CommonSubexprEliminate { assert_eq!(new_expr.len(), 1); // passed in vec![predicate] let new_predicate = new_expr.pop().unwrap(); Filter::try_new(new_predicate, Arc::new(new_input)) - .map(LogicalPlan::Filter) + .map(LogicalPlan::filter) }) } @@ -213,7 +213,7 @@ impl CommonSubexprEliminate { }) .collect::>(); Window::try_new(new_window_expr, Arc::new(plan)) - .map(LogicalPlan::Window) + .map(LogicalPlan::window) }, ) } else { @@ -226,7 +226,7 @@ impl CommonSubexprEliminate { Arc::new(plan), schema, ) - .map(LogicalPlan::Window) + .map(LogicalPlan::window) }) } }) @@ -331,12 +331,12 @@ impl CommonSubexprEliminate { rewritten_aggr_expr.into_iter().zip(new_aggr_expr) { if expr_rewritten == expr_orig { - if let Expr::Alias(Alias { expr, name, .. }) = + if let Expr::Alias(Alias { expr, name, .. }, _) = expr_rewritten { agg_exprs.push(expr.alias(&name)); proj_exprs - .push(Expr::Column(Column::from_name(name))); + .push(Expr::column(Column::from_name(name))); } else { let expr_alias = config.alias_generator().next(CSE_PREFIX); @@ -347,7 +347,7 @@ impl CommonSubexprEliminate { agg_exprs.push(expr_rewritten.alias(&expr_alias)); proj_exprs.push( - Expr::Column(Column::from_name(expr_alias)) + Expr::column(Column::from_name(expr_alias)) .alias(out_name), ); } @@ -356,13 +356,13 @@ impl CommonSubexprEliminate { } } - let agg = LogicalPlan::Aggregate(Aggregate::try_new( + let agg = LogicalPlan::aggregate(Aggregate::try_new( new_input, new_group_expr, agg_exprs, )?); Projection::try_new(proj_exprs, Arc::new(agg)) - .map(|p| Transformed::yes(LogicalPlan::Projection(p))) + .map(|p| Transformed::yes(LogicalPlan::projection(p))) } // If there aren't any common aggregate sub-expressions, then just @@ -399,7 +399,7 @@ impl CommonSubexprEliminate { // Since `group_expr` may have changed, schema may also. // Use `try_new()` method. Aggregate::try_new(new_input, new_group_expr, new_aggr_expr) - .map(LogicalPlan::Aggregate) + .map(LogicalPlan::aggregate) .map(Transformed::no) } else { Aggregate::try_new_with_schema( @@ -408,7 +408,7 @@ impl CommonSubexprEliminate { rewritten_aggr_expr, schema, ) - .map(LogicalPlan::Aggregate) + .map(LogicalPlan::aggregate) .map(Transformed::no) } } @@ -505,12 +505,15 @@ fn get_consecutive_window_exprs( ) -> (Vec>, Vec, LogicalPlan) { let mut window_expr_list = vec![]; let mut window_schemas = vec![]; - let mut plan = LogicalPlan::Window(window); - while let LogicalPlan::Window(Window { - input, - window_expr, - schema, - }) = plan + let mut plan = LogicalPlan::window(window); + while let LogicalPlan::Window( + Window { + input, + window_expr, + schema, + }, + _, + ) = plan { window_expr_list.push(window_expr); window_schemas.push(schema); @@ -541,31 +544,31 @@ impl OptimizerRule for CommonSubexprEliminate { let original_schema = Arc::clone(plan.schema()); let optimized_plan = match plan { - LogicalPlan::Projection(proj) => self.try_optimize_proj(proj, config)?, - LogicalPlan::Sort(sort) => self.try_optimize_sort(sort, config)?, - LogicalPlan::Filter(filter) => self.try_optimize_filter(filter, config)?, - LogicalPlan::Window(window) => self.try_optimize_window(window, config)?, - LogicalPlan::Aggregate(agg) => self.try_optimize_aggregate(agg, config)?, - LogicalPlan::Join(_) - | LogicalPlan::Repartition(_) - | LogicalPlan::Union(_) - | LogicalPlan::TableScan(_) - | LogicalPlan::Values(_) - | LogicalPlan::EmptyRelation(_) - | LogicalPlan::Subquery(_) - | LogicalPlan::SubqueryAlias(_) - | LogicalPlan::Limit(_) - | LogicalPlan::Ddl(_) - | LogicalPlan::Explain(_) - | LogicalPlan::Analyze(_) - | LogicalPlan::Statement(_) - | LogicalPlan::DescribeTable(_) - | LogicalPlan::Distinct(_) - | LogicalPlan::Extension(_) - | LogicalPlan::Dml(_) - | LogicalPlan::Copy(_) - | LogicalPlan::Unnest(_) - | LogicalPlan::RecursiveQuery(_) => { + LogicalPlan::Projection(proj, _) => self.try_optimize_proj(proj, config)?, + LogicalPlan::Sort(sort, _) => self.try_optimize_sort(sort, config)?, + LogicalPlan::Filter(filter, _) => self.try_optimize_filter(filter, config)?, + LogicalPlan::Window(window, _) => self.try_optimize_window(window, config)?, + LogicalPlan::Aggregate(agg, _) => self.try_optimize_aggregate(agg, config)?, + LogicalPlan::Join(_, _) + | LogicalPlan::Repartition(_, _) + | LogicalPlan::Union(_, _) + | LogicalPlan::TableScan(_, _) + | LogicalPlan::Values(_, _) + | LogicalPlan::EmptyRelation(_, _) + | LogicalPlan::Subquery(_, _) + | LogicalPlan::SubqueryAlias(_, _) + | LogicalPlan::Limit(_, _) + | LogicalPlan::Ddl(_, _) + | LogicalPlan::Explain(_, _) + | LogicalPlan::Analyze(_, _) + | LogicalPlan::Statement(_, _) + | LogicalPlan::DescribeTable(_, _) + | LogicalPlan::Distinct(_, _) + | LogicalPlan::Extension(_, _) + | LogicalPlan::Dml(_, _) + | LogicalPlan::Copy(_, _) + | LogicalPlan::Unnest(_, _) + | LogicalPlan::RecursiveQuery(_, _) => { // This rule handles recursion itself in a `ApplyOrder::TopDown` like // manner. plan.map_children(|c| self.rewrite(c, config))? @@ -631,7 +634,7 @@ impl CSEController for ExprCSEController<'_> { // In case of `ScalarFunction`s we don't know which children are surely // executed so start visiting all children conditionally and stop the // recursion with `TreeNodeRecursion::Jump`. - Expr::ScalarFunction(ScalarFunction { func, args }) + Expr::ScalarFunction(ScalarFunction { func, args }, _) if func.short_circuits() => { Some((vec![], args.iter().collect())) @@ -639,20 +642,26 @@ impl CSEController for ExprCSEController<'_> { // In case of `And` and `Or` the first child is surely executed, but we // account subexpressions as conditional in the second. - Expr::BinaryExpr(BinaryExpr { - left, - op: Operator::And | Operator::Or, - right, - }) => Some((vec![left.as_ref()], vec![right.as_ref()])), + Expr::BinaryExpr( + BinaryExpr { + left, + op: Operator::And | Operator::Or, + right, + }, + _, + ) => Some((vec![left.as_ref()], vec![right.as_ref()])), // In case of `Case` the optional base expression and the first when // expressions are surely executed, but we account subexpressions as // conditional in the others. - Expr::Case(Case { - expr, - when_then_expr, - else_expr, - }) => Some(( + Expr::Case( + Case { + expr, + when_then_expr, + else_expr, + }, + _, + ) => Some(( expr.iter() .map(|e| e.as_ref()) .chain(when_then_expr.iter().take(1).map(|(when, _)| when.as_ref())) @@ -711,12 +720,12 @@ impl CSEController for ExprCSEController<'_> { } fn rewrite_f_down(&mut self, node: &Expr) { - if matches!(node, Expr::Alias(_)) { + if matches!(node, Expr::Alias(_, _)) { self.alias_counter += 1; } } fn rewrite_f_up(&mut self, node: &Expr) { - if matches!(node, Expr::Alias(_)) { + if matches!(node, Expr::Alias(_, _)) { self.alias_counter -= 1 } } @@ -757,7 +766,7 @@ fn build_common_expr_project_plan( } } - Projection::try_new(project_exprs, Arc::new(input)).map(LogicalPlan::Projection) + Projection::try_new(project_exprs, Arc::new(input)).map(LogicalPlan::projection) } /// Build the projection plan to eliminate unnecessary columns produced by @@ -770,20 +779,20 @@ fn build_recover_project_plan( input: LogicalPlan, ) -> Result { let col_exprs = schema.iter().map(Expr::from).collect(); - Projection::try_new(col_exprs, Arc::new(input)).map(LogicalPlan::Projection) + Projection::try_new(col_exprs, Arc::new(input)).map(LogicalPlan::projection) } fn extract_expressions(expr: &Expr, result: &mut Vec) { - if let Expr::GroupingSet(groupings) = expr { + if let Expr::GroupingSet(groupings, _) = expr { for e in groupings.distinct_expr() { let (qualifier, field_name) = e.qualified_name(); let col = Column::new(qualifier, field_name); - result.push(Expr::Column(col)) + result.push(Expr::column(col)) } } else { let (qualifier, field_name) = expr.qualified_name(); let col = Column::new(qualifier, field_name); - result.push(Expr::Column(col)); + result.push(Expr::column(col)); } } @@ -878,7 +887,7 @@ mod test { let return_type = DataType::UInt32; let accumulator: AccumulatorFactoryFunction = Arc::new(|_| unimplemented!()); let udf_agg = |inner: Expr| { - Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf( + Expr::aggregate_function(datafusion_expr::expr::AggregateFunction::new_udf( Arc::new(AggregateUDF::from(SimpleAggregateUDF::new_with_signature( "my_agg", Signature::exact(vec![DataType::UInt32], Volatility::Stable), @@ -1004,7 +1013,7 @@ mod test { let schema = Schema::new(vec![Field::new("col.a", DataType::UInt32, false)]); let table_scan = table_scan(Some("table.test"), &schema, None)?.build()?; - let col_a = Expr::Column(Column::new(Some("table.test"), "col.a")); + let col_a = Expr::column(Column::new(Some("table.test"), "col.a")); let plan = LogicalPlanBuilder::from(table_scan) .aggregate( diff --git a/datafusion/optimizer/src/decorrelate.rs b/datafusion/optimizer/src/decorrelate.rs index b5726d9991379..072448407f3cf 100644 --- a/datafusion/optimizer/src/decorrelate.rs +++ b/datafusion/optimizer/src/decorrelate.rs @@ -123,8 +123,10 @@ impl TreeNodeRewriter for PullUpCorrelatedExpr { fn f_down(&mut self, plan: LogicalPlan) -> Result> { match plan { - LogicalPlan::Filter(_) => Ok(Transformed::no(plan)), - LogicalPlan::Union(_) | LogicalPlan::Sort(_) | LogicalPlan::Extension(_) => { + LogicalPlan::Filter(_, _) => Ok(Transformed::no(plan)), + LogicalPlan::Union(_, _) + | LogicalPlan::Sort(_, _) + | LogicalPlan::Extension(_, _) => { let plan_hold_outer = !plan.all_out_ref_exprs().is_empty(); if plan_hold_outer { // the unsupported case @@ -134,7 +136,7 @@ impl TreeNodeRewriter for PullUpCorrelatedExpr { Ok(Transformed::no(plan)) } } - LogicalPlan::Limit(_) => { + LogicalPlan::Limit(_, _) => { let plan_hold_outer = !plan.all_out_ref_exprs().is_empty(); match (self.exists_sub_query, plan_hold_outer) { (false, true) => { @@ -157,7 +159,7 @@ impl TreeNodeRewriter for PullUpCorrelatedExpr { fn f_up(&mut self, plan: LogicalPlan) -> Result> { let subquery_schema = plan.schema(); match &plan { - LogicalPlan::Filter(plan_filter) => { + LogicalPlan::Filter(plan_filter, _) => { let subquery_filter_exprs = split_conjunction(&plan_filter.predicate); self.can_pull_over_aggregation = self.can_pull_over_aggregation && subquery_filter_exprs @@ -224,7 +226,7 @@ impl TreeNodeRewriter for PullUpCorrelatedExpr { } } } - LogicalPlan::Projection(projection) + LogicalPlan::Projection(projection, _) if self.in_predicate_opt.is_some() || !self.join_filters.is_empty() => { let mut local_correlated_cols = BTreeSet::new(); @@ -249,7 +251,7 @@ impl TreeNodeRewriter for PullUpCorrelatedExpr { )?; if !expr_result_map_for_count_bug.is_empty() { // has count bug - let un_matched_row = Expr::Column(Column::new_unqualified( + let un_matched_row = Expr::column(Column::new_unqualified( UN_MATCHED_ROW_INDICATOR.to_string(), )); // add the unmatched rows indicator to the Projection expressions @@ -266,7 +268,7 @@ impl TreeNodeRewriter for PullUpCorrelatedExpr { } Ok(Transformed::yes(new_plan)) } - LogicalPlan::Aggregate(aggregate) + LogicalPlan::Aggregate(aggregate, _) if self.in_predicate_opt.is_some() || !self.join_filters.is_empty() => { // If the aggregation is from a distinct it will not change the result for @@ -314,7 +316,7 @@ impl TreeNodeRewriter for PullUpCorrelatedExpr { } Ok(Transformed::yes(new_plan)) } - LogicalPlan::SubqueryAlias(alias) => { + LogicalPlan::SubqueryAlias(alias, _) => { let mut local_correlated_cols = BTreeSet::new(); collect_local_correlated_cols( &plan, @@ -336,7 +338,7 @@ impl TreeNodeRewriter for PullUpCorrelatedExpr { } Ok(Transformed::no(plan)) } - LogicalPlan::Limit(limit) => { + LogicalPlan::Limit(limit, _) => { let input_expr_map = self .collected_count_expr_map .get(limit.input.deref()) @@ -347,7 +349,7 @@ impl TreeNodeRewriter for PullUpCorrelatedExpr { // Correlated exist subquery, remove the limit(so that correlated expressions can pull up) (true, false) => Transformed::yes(match limit.get_fetch_type()? { FetchType::Literal(Some(0)) => { - LogicalPlan::EmptyRelation(EmptyRelation { + LogicalPlan::empty_relation(EmptyRelation { produce_one_row: false, schema: Arc::clone(limit.input.schema()), }) @@ -380,7 +382,7 @@ impl PullUpCorrelatedExpr { } } for col in correlated_subquery_cols.iter() { - let col_expr = Expr::Column(col.clone()); + let col_expr = Expr::column(col.clone()); if !missing_exprs.contains(&col_expr) { missing_exprs.push(col_expr) } @@ -391,9 +393,9 @@ impl PullUpCorrelatedExpr { // add to missing_exprs if not already there let contains = missing_exprs .iter() - .any(|expr| matches!(expr, Expr::Column(c) if c == col)); + .any(|expr| matches!(expr, Expr::Column(c, _) if c == col)); if !contains { - missing_exprs.push(Expr::Column(col.clone())) + missing_exprs.push(Expr::column(col.clone())) } } } @@ -402,22 +404,25 @@ impl PullUpCorrelatedExpr { } fn can_pullup_over_aggregation(expr: &Expr) -> bool { - if let Expr::BinaryExpr(BinaryExpr { - left, - op: Operator::Eq, - right, - }) = expr + if let Expr::BinaryExpr( + BinaryExpr { + left, + op: Operator::Eq, + right, + }, + _, + ) = expr { match (left.deref(), right.deref()) { - (Expr::Column(_), right) => !right.any_column_refs(), - (left, Expr::Column(_)) => !left.any_column_refs(), - (Expr::Cast(Cast { expr, .. }), right) - if matches!(expr.deref(), Expr::Column(_)) => + (Expr::Column(_, _), right) => !right.any_column_refs(), + (left, Expr::Column(_, _)) => !left.any_column_refs(), + (Expr::Cast(Cast { expr, .. }, _), right) + if matches!(expr.deref(), Expr::Column(_, _)) => { !right.any_column_refs() } - (left, Expr::Cast(Cast { expr, .. })) - if matches!(expr.deref(), Expr::Column(_)) => + (left, Expr::Cast(Cast { expr, .. }, _)) + if matches!(expr.deref(), Expr::Column(_, _)) => { !left.any_column_refs() } @@ -438,7 +443,7 @@ fn collect_local_correlated_cols( local_cols.extend(cols.clone()); } // SubqueryAlias is treated as the leaf node - if !matches!(child, LogicalPlan::SubqueryAlias(_)) { + if !matches!(child, LogicalPlan::SubqueryAlias(_, _)) { collect_local_correlated_cols(child, all_cols_map, local_cols); } } @@ -454,7 +459,7 @@ fn remove_duplicated_filter(filters: Vec, in_predicate: &Expr) -> Vec { + (Expr::BinaryExpr(a_expr, _), Expr::BinaryExpr(b_expr, _)) => { (a_expr.op == b_expr.op) && (a_expr.left == b_expr.left && a_expr.right == b_expr.right) || (a_expr.left == b_expr.right && a_expr.right == b_expr.left) @@ -475,11 +480,11 @@ fn agg_exprs_evaluation_result_on_empty_batch( .clone() .transform_up(|expr| { let new_expr = match expr { - Expr::AggregateFunction(expr::AggregateFunction { func, .. }) => { + Expr::AggregateFunction(expr::AggregateFunction { func, .. }, _) => { if func.name() == "count" { - Transformed::yes(Expr::Literal(ScalarValue::Int64(Some(0)))) + Transformed::yes(Expr::literal(ScalarValue::Int64(Some(0)))) } else { - Transformed::yes(Expr::Literal(ScalarValue::Null)) + Transformed::yes(Expr::literal(ScalarValue::Null)) } } _ => Transformed::no(expr), @@ -493,7 +498,7 @@ fn agg_exprs_evaluation_result_on_empty_batch( let info = SimplifyContext::new(&props).with_schema(Arc::clone(schema)); let simplifier = ExprSimplifier::new(info); let result_expr = simplifier.simplify(result_expr)?; - if matches!(result_expr, Expr::Literal(ScalarValue::Int64(_))) { + if matches!(result_expr, Expr::Literal(ScalarValue::Int64(_), _)) { expr_result_map_for_count_bug .insert(e.schema_name().to_string(), result_expr); } @@ -511,7 +516,7 @@ fn proj_exprs_evaluation_result_on_empty_batch( let result_expr = expr .clone() .transform_up(|expr| { - if let Expr::Column(Column { name, .. }) = &expr { + if let Expr::Column(Column { name, .. }, _) = &expr { if let Some(result_expr) = input_expr_result_map_for_count_bug.get(name) { @@ -531,8 +536,8 @@ fn proj_exprs_evaluation_result_on_empty_batch( let simplifier = ExprSimplifier::new(info); let result_expr = simplifier.simplify(result_expr)?; let expr_name = match expr { - Expr::Alias(Alias { name, .. }) => name.to_string(), - Expr::Column(Column { relation: _, name }) => name.to_string(), + Expr::Alias(Alias { name, .. }, _) => name.to_string(), + Expr::Column(Column { relation: _, name }, _) => name.to_string(), _ => expr.schema_name().to_string(), }; expr_result_map_for_count_bug.insert(expr_name, result_expr); @@ -550,7 +555,7 @@ fn filter_exprs_evaluation_result_on_empty_batch( let result_expr = filter_expr .clone() .transform_up(|expr| { - if let Expr::Column(Column { name, .. }) = &expr { + if let Expr::Column(Column { name, .. }, _) = &expr { if let Some(result_expr) = input_expr_result_map_for_count_bug.get(name) { Ok(Transformed::yes(result_expr.clone())) } else { @@ -569,10 +574,10 @@ fn filter_exprs_evaluation_result_on_empty_batch( let result_expr = simplifier.simplify(result_expr)?; match &result_expr { // evaluate to false or null on empty batch, no need to pull up - Expr::Literal(ScalarValue::Null) - | Expr::Literal(ScalarValue::Boolean(Some(false))) => None, + Expr::Literal(ScalarValue::Null, _) + | Expr::Literal(ScalarValue::Boolean(Some(false)), _) => None, // evaluate to true on empty batch, need to pull up the expr - Expr::Literal(ScalarValue::Boolean(Some(true))) => { + Expr::Literal(ScalarValue::Boolean(Some(true)), _) => { for (name, exprs) in input_expr_result_map_for_count_bug { expr_result_map_for_count_bug.insert(name.clone(), exprs.clone()); } @@ -581,13 +586,13 @@ fn filter_exprs_evaluation_result_on_empty_batch( // can not evaluate statically _ => { for input_expr in input_expr_result_map_for_count_bug.values() { - let new_expr = Expr::Case(expr::Case { + let new_expr = Expr::case(expr::Case { expr: None, when_then_expr: vec![( Box::new(result_expr.clone()), Box::new(input_expr.clone()), )], - else_expr: Some(Box::new(Expr::Literal(ScalarValue::Null))), + else_expr: Some(Box::new(Expr::literal(ScalarValue::Null))), }); let expr_key = new_expr.schema_name().to_string(); expr_result_map_for_count_bug.insert(expr_key, new_expr); diff --git a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs index 7fdad5ba4b6e9..ebecd413405e6 100644 --- a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs +++ b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs @@ -66,12 +66,12 @@ impl OptimizerRule for DecorrelatePredicateSubquery { })? .data; - let LogicalPlan::Filter(filter) = plan else { + let LogicalPlan::Filter(filter, _) = plan else { return Ok(Transformed::no(plan)); }; if !has_subquery(&filter.predicate) { - return Ok(Transformed::no(LogicalPlan::Filter(filter))); + return Ok(Transformed::no(LogicalPlan::filter(filter))); } let (with_subqueries, mut other_exprs): (Vec<_>, Vec<_>) = @@ -111,7 +111,7 @@ impl OptimizerRule for DecorrelatePredicateSubquery { let expr = conjunction(other_exprs); if let Some(expr) = expr { let new_filter = Filter::try_new(expr, Arc::new(cur_input))?; - cur_input = LogicalPlan::Filter(new_filter); + cur_input = LogicalPlan::filter(new_filter); } Ok(Transformed::yes(cur_input)) } @@ -133,10 +133,13 @@ fn rewrite_inner_subqueries( let mut cur_input = outer; let alias = config.alias_generator(); let expr_without_subqueries = expr.transform(|e| match e { - Expr::Exists(Exists { - subquery: Subquery { subquery, .. }, - negated, - }) => match mark_join(&cur_input, Arc::clone(&subquery), None, negated, alias)? { + Expr::Exists( + Exists { + subquery: Subquery { subquery, .. }, + negated, + }, + _, + ) => match mark_join(&cur_input, Arc::clone(&subquery), None, negated, alias)? { Some((plan, exists_expr)) => { cur_input = plan; Ok(Transformed::yes(exists_expr)) @@ -144,11 +147,14 @@ fn rewrite_inner_subqueries( None if negated => Ok(Transformed::no(not_exists(subquery))), None => Ok(Transformed::no(exists(subquery))), }, - Expr::InSubquery(InSubquery { - expr, - subquery: Subquery { subquery, .. }, - negated, - }) => { + Expr::InSubquery( + InSubquery { + expr, + subquery: Subquery { subquery, .. }, + negated, + }, + _, + ) => { let in_predicate = subquery .head_output_expr()? .map_or(plan_err!("single expression required."), |output_expr| { @@ -185,27 +191,33 @@ enum SubqueryPredicate { fn extract_subquery_info(expr: Expr) -> SubqueryPredicate { match expr { - Expr::Not(not_expr) => match *not_expr { - Expr::InSubquery(InSubquery { - expr, - subquery, - negated, - }) => SubqueryPredicate::Top(SubqueryInfo::new_with_in_expr( + Expr::Not(not_expr, _) => match *not_expr { + Expr::InSubquery( + InSubquery { + expr, + subquery, + negated, + }, + _, + ) => SubqueryPredicate::Top(SubqueryInfo::new_with_in_expr( subquery, *expr, !negated, )), - Expr::Exists(Exists { subquery, negated }) => { + Expr::Exists(Exists { subquery, negated }, _) => { SubqueryPredicate::Top(SubqueryInfo::new(subquery, !negated)) } expr => SubqueryPredicate::Embedded(not(expr)), }, - Expr::InSubquery(InSubquery { - expr, - subquery, - negated, - }) => SubqueryPredicate::Top(SubqueryInfo::new_with_in_expr( + Expr::InSubquery( + InSubquery { + expr, + subquery, + negated, + }, + _, + ) => SubqueryPredicate::Top(SubqueryInfo::new_with_in_expr( subquery, *expr, negated, )), - Expr::Exists(Exists { subquery, negated }) => { + Expr::Exists(Exists { subquery, negated }, _) => { SubqueryPredicate::Top(SubqueryInfo::new(subquery, negated)) } expr => SubqueryPredicate::Embedded(expr), @@ -214,7 +226,7 @@ fn extract_subquery_info(expr: Expr) -> SubqueryPredicate { fn has_subquery(expr: &Expr) -> bool { expr.exists(|e| match e { - Expr::InSubquery(_) | Expr::Exists(_) => Ok(true), + Expr::InSubquery(_, _) | Expr::Exists(_, _) => Ok(true), _ => Ok(false), }) .unwrap() @@ -302,7 +314,7 @@ fn mark_join( ) -> Result> { let alias = alias_generator.next("__correlated_sq"); - let exists_col = Expr::Column(Column::new(Some(alias.clone()), "mark")); + let exists_col = Expr::column(Column::new(Some(alias.clone()), "mark")); let exists_expr = if negated { !exists_col } else { exists_col }; Ok( @@ -345,27 +357,33 @@ fn build_join( if let Some(join_filter) = match (join_filter_opt, in_predicate_opt) { ( Some(join_filter), - Some(Expr::BinaryExpr(BinaryExpr { - left, - op: Operator::Eq, - right, - })), + Some(Expr::BinaryExpr( + BinaryExpr { + left, + op: Operator::Eq, + right, + }, + _, + )), ) => { let right_col = create_col_from_scalar_expr(right.deref(), alias)?; - let in_predicate = Expr::eq(left.deref().clone(), Expr::Column(right_col)); + let in_predicate = Expr::eq(left.deref().clone(), Expr::column(right_col)); Some(in_predicate.and(join_filter)) } (Some(join_filter), _) => Some(join_filter), ( _, - Some(Expr::BinaryExpr(BinaryExpr { - left, - op: Operator::Eq, - right, - })), + Some(Expr::BinaryExpr( + BinaryExpr { + left, + op: Operator::Eq, + right, + }, + _, + )), ) => { let right_col = create_col_from_scalar_expr(right.deref(), alias)?; - let in_predicate = Expr::eq(left.deref().clone(), Expr::Column(right_col)); + let in_predicate = Expr::eq(left.deref().clone(), Expr::column(right_col)); Some(in_predicate) } _ => None, diff --git a/datafusion/optimizer/src/eliminate_cross_join.rs b/datafusion/optimizer/src/eliminate_cross_join.rs index 32b7ce44a63a5..c5715d639ca9f 100644 --- a/datafusion/optimizer/src/eliminate_cross_join.rs +++ b/datafusion/optimizer/src/eliminate_cross_join.rs @@ -91,24 +91,27 @@ impl OptimizerRule for EliminateCrossJoin { let mut all_inputs: Vec = vec![]; let mut all_filters: Vec = vec![]; - let parent_predicate = if let LogicalPlan::Filter(filter) = plan { + let parent_predicate = if let LogicalPlan::Filter(filter, _) = plan { // if input isn't a join that can potentially be rewritten // avoid unwrapping the input let rewriteable = matches!( filter.input.as_ref(), - LogicalPlan::Join(Join { - join_type: JoinType::Inner, - .. - }) + LogicalPlan::Join( + Join { + join_type: JoinType::Inner, + .. + }, + _ + ) ); if !rewriteable { // recursively try to rewrite children - return rewrite_children(self, LogicalPlan::Filter(filter), config); + return rewrite_children(self, LogicalPlan::filter(filter), config); } if !can_flatten_join_inputs(&filter.input) { - return Ok(Transformed::no(LogicalPlan::Filter(filter))); + return Ok(Transformed::no(LogicalPlan::filter(filter))); } let Filter { @@ -125,10 +128,13 @@ impl OptimizerRule for EliminateCrossJoin { Some(predicate) } else if matches!( plan, - LogicalPlan::Join(Join { - join_type: JoinType::Inner, - .. - }) + LogicalPlan::Join( + Join { + join_type: JoinType::Inner, + .. + }, + _ + ) ) { if !can_flatten_join_inputs(&plan) { return Ok(Transformed::no(plan)); @@ -160,7 +166,7 @@ impl OptimizerRule for EliminateCrossJoin { left = rewrite_children(self, left, config)?.data; if &plan_schema != left.schema() { - left = LogicalPlan::Projection(Projection::new_from_schema( + left = LogicalPlan::projection(Projection::new_from_schema( Arc::new(left), Arc::clone(&plan_schema), )); @@ -170,7 +176,7 @@ impl OptimizerRule for EliminateCrossJoin { // Add any filters on top - PushDownFilter can push filters down to applicable join let first = all_filters.swap_remove(0); let predicate = all_filters.into_iter().fold(first, and); - left = LogicalPlan::Filter(Filter::try_new(predicate, Arc::new(left))?); + left = LogicalPlan::filter(Filter::try_new(predicate, Arc::new(left))?); } let Some(predicate) = parent_predicate else { @@ -180,12 +186,12 @@ impl OptimizerRule for EliminateCrossJoin { // If there are no join keys then do nothing: if all_join_keys.is_empty() { Filter::try_new(predicate, Arc::new(left)) - .map(|filter| Transformed::yes(LogicalPlan::Filter(filter))) + .map(|filter| Transformed::yes(LogicalPlan::filter(filter))) } else { // Remove join expressions from filter: match remove_join_expressions(predicate, &all_join_keys) { Some(filter_expr) => Filter::try_new(filter_expr, Arc::new(left)) - .map(|filter| Transformed::yes(LogicalPlan::Filter(filter))), + .map(|filter| Transformed::yes(LogicalPlan::filter(filter))), _ => Ok(Transformed::yes(left)), } } @@ -224,7 +230,7 @@ fn flatten_join_inputs( all_filters: &mut Vec, ) -> Result<()> { match plan { - LogicalPlan::Join(join) if join.join_type == JoinType::Inner => { + LogicalPlan::Join(join, _) if join.join_type == JoinType::Inner => { if let Some(filter) = join.filter { all_filters.push(filter); } @@ -256,15 +262,18 @@ fn flatten_join_inputs( fn can_flatten_join_inputs(plan: &LogicalPlan) -> bool { // can only flatten inner / cross joins match plan { - LogicalPlan::Join(join) if join.join_type == JoinType::Inner => {} + LogicalPlan::Join(join, _) if join.join_type == JoinType::Inner => {} _ => return false, }; for child in plan.inputs() { - if let LogicalPlan::Join(Join { - join_type: JoinType::Inner, - .. - }) = child + if let LogicalPlan::Join( + Join { + join_type: JoinType::Inner, + .. + }, + _, + ) = child { if !can_flatten_join_inputs(child) { return false; @@ -321,7 +330,7 @@ fn find_inner_join( &JoinType::Inner, )?); - return Ok(LogicalPlan::Join(Join { + return Ok(LogicalPlan::join(Join { left: Arc::new(left_input), right: Arc::new(right_input), join_type: JoinType::Inner, @@ -343,7 +352,7 @@ fn find_inner_join( &JoinType::Inner, )?); - Ok(LogicalPlan::Join(Join { + Ok(LogicalPlan::join(Join { left: Arc::new(left_input), right: Arc::new(right), schema: join_schema, @@ -357,7 +366,7 @@ fn find_inner_join( /// Extract join keys from a WHERE clause fn extract_possible_join_keys(expr: &Expr, join_keys: &mut JoinKeySet) { - if let Expr::BinaryExpr(BinaryExpr { left, op, right }) = expr { + if let Expr::BinaryExpr(BinaryExpr { left, op, right }, _) = expr { match op { Operator::Eq => { // insert handles ensuring we don't add the same Join keys multiple times @@ -389,20 +398,23 @@ fn extract_possible_join_keys(expr: &Expr, join_keys: &mut JoinKeySet) { /// * `None` otherwise fn remove_join_expressions(expr: Expr, join_keys: &JoinKeySet) -> Option { match expr { - Expr::BinaryExpr(BinaryExpr { - left, - op: Operator::Eq, - right, - }) if join_keys.contains(&left, &right) => { + Expr::BinaryExpr( + BinaryExpr { + left, + op: Operator::Eq, + right, + }, + _, + ) if join_keys.contains(&left, &right) => { // was a join key, so remove it None } // Fix for issue#78 join predicates from inside of OR expr also pulled up properly. - Expr::BinaryExpr(BinaryExpr { left, op, right }) if op == Operator::And => { + Expr::BinaryExpr(BinaryExpr { left, op, right }, _) if op == Operator::And => { let l = remove_join_expressions(*left, join_keys); let r = remove_join_expressions(*right, join_keys); match (l, r) { - (Some(ll), Some(rr)) => Some(Expr::BinaryExpr(BinaryExpr::new( + (Some(ll), Some(rr)) => Some(Expr::binary_expr(BinaryExpr::new( Box::new(ll), op, Box::new(rr), @@ -412,11 +424,11 @@ fn remove_join_expressions(expr: Expr, join_keys: &JoinKeySet) -> Option { _ => None, } } - Expr::BinaryExpr(BinaryExpr { left, op, right }) if op == Operator::Or => { + Expr::BinaryExpr(BinaryExpr { left, op, right }, _) if op == Operator::Or => { let l = remove_join_expressions(*left, join_keys); let r = remove_join_expressions(*right, join_keys); match (l, r) { - (Some(ll), Some(rr)) => Some(Expr::BinaryExpr(BinaryExpr::new( + (Some(ll), Some(rr)) => Some(Expr::binary_expr(BinaryExpr::new( Box::new(ll), op, Box::new(rr), diff --git a/datafusion/optimizer/src/eliminate_duplicated_expr.rs b/datafusion/optimizer/src/eliminate_duplicated_expr.rs index 554985667fdf9..4fca0dfd4c4a6 100644 --- a/datafusion/optimizer/src/eliminate_duplicated_expr.rs +++ b/datafusion/optimizer/src/eliminate_duplicated_expr.rs @@ -63,7 +63,7 @@ impl OptimizerRule for EliminateDuplicatedExpr { _config: &dyn OptimizerConfig, ) -> Result> { match plan { - LogicalPlan::Sort(sort) => { + LogicalPlan::Sort(sort, _) => { let len = sort.expr.len(); let unique_exprs: Vec<_> = sort .expr @@ -80,13 +80,13 @@ impl OptimizerRule for EliminateDuplicatedExpr { Transformed::no }; - Ok(transformed(LogicalPlan::Sort(Sort { + Ok(transformed(LogicalPlan::sort(Sort { expr: unique_exprs, input: sort.input, fetch: sort.fetch, }))) } - LogicalPlan::Aggregate(agg) => { + LogicalPlan::Aggregate(agg, _) => { let len = agg.group_expr.len(); let unique_exprs: Vec = agg @@ -103,7 +103,7 @@ impl OptimizerRule for EliminateDuplicatedExpr { }; Aggregate::try_new(agg.input, unique_exprs, agg.aggr_expr) - .map(|f| transformed(LogicalPlan::Aggregate(f))) + .map(|f| transformed(LogicalPlan::aggregate(f))) } _ => Ok(Transformed::no(plan)), } diff --git a/datafusion/optimizer/src/eliminate_filter.rs b/datafusion/optimizer/src/eliminate_filter.rs index 4ed2ac8ba1a4e..38d4a531444d2 100644 --- a/datafusion/optimizer/src/eliminate_filter.rs +++ b/datafusion/optimizer/src/eliminate_filter.rs @@ -59,13 +59,16 @@ impl OptimizerRule for EliminateFilter { _config: &dyn OptimizerConfig, ) -> Result> { match plan { - LogicalPlan::Filter(Filter { - predicate: Expr::Literal(ScalarValue::Boolean(v)), - input, - .. - }) => match v { + LogicalPlan::Filter( + Filter { + predicate: Expr::Literal(ScalarValue::Boolean(v), _), + input, + .. + }, + _, + ) => match v { Some(true) => Ok(Transformed::yes(Arc::unwrap_or_clone(input))), - Some(false) | None => Ok(Transformed::yes(LogicalPlan::EmptyRelation( + Some(false) | None => Ok(Transformed::yes(LogicalPlan::empty_relation( EmptyRelation { produce_one_row: false, schema: Arc::clone(input.schema()), @@ -111,7 +114,7 @@ mod tests { #[test] fn filter_null() -> Result<()> { - let filter_expr = Expr::Literal(ScalarValue::Boolean(None)); + let filter_expr = Expr::literal(ScalarValue::Boolean(None)); let table_scan = test_table_scan().unwrap(); let plan = LogicalPlanBuilder::from(table_scan) diff --git a/datafusion/optimizer/src/eliminate_group_by_constant.rs b/datafusion/optimizer/src/eliminate_group_by_constant.rs index 13d03d647fe20..0023b2fe012ae 100644 --- a/datafusion/optimizer/src/eliminate_group_by_constant.rs +++ b/datafusion/optimizer/src/eliminate_group_by_constant.rs @@ -46,7 +46,7 @@ impl OptimizerRule for EliminateGroupByConstant { _config: &dyn OptimizerConfig, ) -> Result> { match plan { - LogicalPlan::Aggregate(aggregate) => { + LogicalPlan::Aggregate(aggregate, _) => { let (const_group_expr, nonconst_group_expr): (Vec<_>, Vec<_>) = aggregate .group_expr .iter() @@ -60,10 +60,10 @@ impl OptimizerRule for EliminateGroupByConstant { && nonconst_group_expr.is_empty() && aggregate.aggr_expr.is_empty()) { - return Ok(Transformed::no(LogicalPlan::Aggregate(aggregate))); + return Ok(Transformed::no(LogicalPlan::aggregate(aggregate))); } - let simplified_aggregate = LogicalPlan::Aggregate(Aggregate::try_new( + let simplified_aggregate = LogicalPlan::aggregate(Aggregate::try_new( aggregate.input, nonconst_group_expr.into_iter().cloned().collect(), aggregate.aggr_expr.clone(), @@ -97,12 +97,12 @@ impl OptimizerRule for EliminateGroupByConstant { /// reiles on `SimplifyExpressions` result. fn is_constant_expression(expr: &Expr) -> bool { match expr { - Expr::Alias(e) => is_constant_expression(&e.expr), - Expr::BinaryExpr(e) => { + Expr::Alias(e, _) => is_constant_expression(&e.expr), + Expr::BinaryExpr(e, _) => { is_constant_expression(&e.left) && is_constant_expression(&e.right) } - Expr::Literal(_) => true, - Expr::ScalarFunction(e) => { + Expr::Literal(_, _) => true, + Expr::ScalarFunction(e, _) => { matches!( e.func.signature().volatility, Volatility::Immutable | Volatility::Stable @@ -267,7 +267,7 @@ mod tests { Volatility::Immutable, )); let udf_expr = - Expr::ScalarFunction(ScalarFunction::new_udf(udf.into(), vec![lit(123u32)])); + Expr::scalar_function(ScalarFunction::new_udf(udf.into(), vec![lit(123u32)])); let scan = test_table_scan()?; let plan = LogicalPlanBuilder::from(scan) .aggregate(vec![udf_expr, col("a")], vec![count(col("c"))])? @@ -292,7 +292,7 @@ mod tests { Volatility::Volatile, )); let udf_expr = - Expr::ScalarFunction(ScalarFunction::new_udf(udf.into(), vec![lit(123u32)])); + Expr::scalar_function(ScalarFunction::new_udf(udf.into(), vec![lit(123u32)])); let scan = test_table_scan()?; let plan = LogicalPlanBuilder::from(scan) .aggregate(vec![udf_expr, col("a")], vec![count(col("c"))])? diff --git a/datafusion/optimizer/src/eliminate_join.rs b/datafusion/optimizer/src/eliminate_join.rs index 789235595dabf..e36c20d6c8983 100644 --- a/datafusion/optimizer/src/eliminate_join.rs +++ b/datafusion/optimizer/src/eliminate_join.rs @@ -52,15 +52,17 @@ impl OptimizerRule for EliminateJoin { _config: &dyn OptimizerConfig, ) -> Result> { match plan { - LogicalPlan::Join(join) if join.join_type == Inner && join.on.is_empty() => { + LogicalPlan::Join(join, _) + if join.join_type == Inner && join.on.is_empty() => + { match join.filter { - Some(Expr::Literal(ScalarValue::Boolean(Some(false)))) => Ok( - Transformed::yes(LogicalPlan::EmptyRelation(EmptyRelation { + Some(Expr::Literal(ScalarValue::Boolean(Some(false)), _)) => Ok( + Transformed::yes(LogicalPlan::empty_relation(EmptyRelation { produce_one_row: false, schema: join.schema, })), ), - _ => Ok(Transformed::no(LogicalPlan::Join(join))), + _ => Ok(Transformed::no(LogicalPlan::join(join))), } } _ => Ok(Transformed::no(plan)), diff --git a/datafusion/optimizer/src/eliminate_limit.rs b/datafusion/optimizer/src/eliminate_limit.rs index 267615c3e0d93..d47aa3a48ec17 100644 --- a/datafusion/optimizer/src/eliminate_limit.rs +++ b/datafusion/optimizer/src/eliminate_limit.rs @@ -59,15 +59,15 @@ impl OptimizerRule for EliminateLimit { _config: &dyn OptimizerConfig, ) -> Result, datafusion_common::DataFusionError> { match plan { - LogicalPlan::Limit(limit) => { + LogicalPlan::Limit(limit, _) => { // Only supports rewriting for literal fetch let FetchType::Literal(fetch) = limit.get_fetch_type()? else { - return Ok(Transformed::no(LogicalPlan::Limit(limit))); + return Ok(Transformed::no(LogicalPlan::limit(limit))); }; if let Some(v) = fetch { if v == 0 { - return Ok(Transformed::yes(LogicalPlan::EmptyRelation( + return Ok(Transformed::yes(LogicalPlan::empty_relation( EmptyRelation { produce_one_row: false, schema: Arc::clone(limit.input.schema()), @@ -79,7 +79,7 @@ impl OptimizerRule for EliminateLimit { // we can remove it. Its input also can be Limit, so we should apply again. return self.rewrite(Arc::unwrap_or_clone(limit.input), _config); } - Ok(Transformed::no(LogicalPlan::Limit(limit))) + Ok(Transformed::no(LogicalPlan::limit(limit))) } _ => Ok(Transformed::no(plan)), } diff --git a/datafusion/optimizer/src/eliminate_nested_union.rs b/datafusion/optimizer/src/eliminate_nested_union.rs index 94da08243d78f..4979ddc2f3ac1 100644 --- a/datafusion/optimizer/src/eliminate_nested_union.rs +++ b/datafusion/optimizer/src/eliminate_nested_union.rs @@ -55,21 +55,21 @@ impl OptimizerRule for EliminateNestedUnion { _config: &dyn OptimizerConfig, ) -> Result> { match plan { - LogicalPlan::Union(Union { inputs, schema }) => { + LogicalPlan::Union(Union { inputs, schema }, _) => { let inputs = inputs .into_iter() .flat_map(extract_plans_from_union) .map(|plan| coerce_plan_expr_for_schema(plan, &schema)) .collect::>>()?; - Ok(Transformed::yes(LogicalPlan::Union(Union { + Ok(Transformed::yes(LogicalPlan::union(Union { inputs: inputs.into_iter().map(Arc::new).collect_vec(), schema, }))) } - LogicalPlan::Distinct(Distinct::All(nested_plan)) => { + LogicalPlan::Distinct(Distinct::All(nested_plan), _) => { match Arc::unwrap_or_clone(nested_plan) { - LogicalPlan::Union(Union { inputs, schema }) => { + LogicalPlan::Union(Union { inputs, schema }, _) => { let inputs = inputs .into_iter() .map(extract_plan_from_distinct) @@ -77,14 +77,14 @@ impl OptimizerRule for EliminateNestedUnion { .map(|plan| coerce_plan_expr_for_schema(plan, &schema)) .collect::>>()?; - Ok(Transformed::yes(LogicalPlan::Distinct(Distinct::All( - Arc::new(LogicalPlan::Union(Union { + Ok(Transformed::yes(LogicalPlan::distinct(Distinct::All( + Arc::new(LogicalPlan::union(Union { inputs: inputs.into_iter().map(Arc::new).collect_vec(), schema: Arc::clone(&schema), })), )))) } - nested_plan => Ok(Transformed::no(LogicalPlan::Distinct( + nested_plan => Ok(Transformed::no(LogicalPlan::distinct( Distinct::All(Arc::new(nested_plan)), ))), } @@ -96,7 +96,7 @@ impl OptimizerRule for EliminateNestedUnion { fn extract_plans_from_union(plan: Arc) -> Vec { match Arc::unwrap_or_clone(plan) { - LogicalPlan::Union(Union { inputs, .. }) => inputs + LogicalPlan::Union(Union { inputs, .. }, _) => inputs .into_iter() .map(Arc::unwrap_or_clone) .collect::>(), @@ -106,7 +106,7 @@ fn extract_plans_from_union(plan: Arc) -> Vec { fn extract_plan_from_distinct(plan: Arc) -> Arc { match Arc::unwrap_or_clone(plan) { - LogicalPlan::Distinct(Distinct::All(plan)) => plan, + LogicalPlan::Distinct(Distinct::All(plan), _) => plan, plan => Arc::new(plan), } } diff --git a/datafusion/optimizer/src/eliminate_one_union.rs b/datafusion/optimizer/src/eliminate_one_union.rs index 3e027811420c4..ac3da4e8f65d8 100644 --- a/datafusion/optimizer/src/eliminate_one_union.rs +++ b/datafusion/optimizer/src/eliminate_one_union.rs @@ -50,7 +50,7 @@ impl OptimizerRule for EliminateOneUnion { _config: &dyn OptimizerConfig, ) -> Result> { match plan { - LogicalPlan::Union(Union { mut inputs, .. }) if inputs.len() == 1 => Ok( + LogicalPlan::Union(Union { mut inputs, .. }, _) if inputs.len() == 1 => Ok( Transformed::yes(Arc::unwrap_or_clone(inputs.pop().unwrap())), ), _ => Ok(Transformed::no(plan)), @@ -110,7 +110,7 @@ mod tests { &schema().to_dfschema()?, )?; let schema = Arc::clone(table_plan.schema()); - let single_union_plan = LogicalPlan::Union(Union { + let single_union_plan = LogicalPlan::union(Union { inputs: vec![Arc::new(table_plan)], schema, }); diff --git a/datafusion/optimizer/src/eliminate_outer_join.rs b/datafusion/optimizer/src/eliminate_outer_join.rs index 1ecb32ca2a435..bca5d61d4c449 100644 --- a/datafusion/optimizer/src/eliminate_outer_join.rs +++ b/datafusion/optimizer/src/eliminate_outer_join.rs @@ -78,56 +78,58 @@ impl OptimizerRule for EliminateOuterJoin { _config: &dyn OptimizerConfig, ) -> Result> { match plan { - LogicalPlan::Filter(mut filter) => match Arc::unwrap_or_clone(filter.input) { - LogicalPlan::Join(join) => { - let mut non_nullable_cols: Vec = vec![]; + LogicalPlan::Filter(mut filter, _) => { + match Arc::unwrap_or_clone(filter.input) { + LogicalPlan::Join(join, _) => { + let mut non_nullable_cols: Vec = vec![]; - extract_non_nullable_columns( - &filter.predicate, - &mut non_nullable_cols, - join.left.schema(), - join.right.schema(), - true, - ); + extract_non_nullable_columns( + &filter.predicate, + &mut non_nullable_cols, + join.left.schema(), + join.right.schema(), + true, + ); - let new_join_type = if join.join_type.is_outer() { - let mut left_non_nullable = false; - let mut right_non_nullable = false; - for col in non_nullable_cols.iter() { - if join.left.schema().has_column(col) { - left_non_nullable = true; - } - if join.right.schema().has_column(col) { - right_non_nullable = true; + let new_join_type = if join.join_type.is_outer() { + let mut left_non_nullable = false; + let mut right_non_nullable = false; + for col in non_nullable_cols.iter() { + if join.left.schema().has_column(col) { + left_non_nullable = true; + } + if join.right.schema().has_column(col) { + right_non_nullable = true; + } } - } - eliminate_outer( - join.join_type, - left_non_nullable, - right_non_nullable, - ) - } else { - join.join_type - }; + eliminate_outer( + join.join_type, + left_non_nullable, + right_non_nullable, + ) + } else { + join.join_type + }; - let new_join = Arc::new(LogicalPlan::Join(Join { - left: join.left, - right: join.right, - join_type: new_join_type, - join_constraint: join.join_constraint, - on: join.on.clone(), - filter: join.filter.clone(), - schema: Arc::clone(&join.schema), - null_equals_null: join.null_equals_null, - })); - Filter::try_new(filter.predicate, new_join) - .map(|f| Transformed::yes(LogicalPlan::Filter(f))) - } - filter_input => { - filter.input = Arc::new(filter_input); - Ok(Transformed::no(LogicalPlan::Filter(filter))) + let new_join = Arc::new(LogicalPlan::join(Join { + left: join.left, + right: join.right, + join_type: new_join_type, + join_constraint: join.join_constraint, + on: join.on.clone(), + filter: join.filter.clone(), + schema: Arc::clone(&join.schema), + null_equals_null: join.null_equals_null, + })); + Filter::try_new(filter.predicate, new_join) + .map(|f| Transformed::yes(LogicalPlan::filter(f))) + } + filter_input => { + filter.input = Arc::new(filter_input); + Ok(Transformed::no(LogicalPlan::filter(filter))) + } } - }, + } _ => Ok(Transformed::no(plan)), } } @@ -180,10 +182,10 @@ fn extract_non_nullable_columns( top_level: bool, ) { match expr { - Expr::Column(col) => { + Expr::Column(col, _) => { non_nullable_cols.push(col.clone()); } - Expr::BinaryExpr(BinaryExpr { left, op, right }) => match op { + Expr::BinaryExpr(BinaryExpr { left, op, right }, _) => match op { // If one of the inputs are null for these operators, the results should be false. Operator::Eq | Operator::NotEq @@ -270,14 +272,14 @@ fn extract_non_nullable_columns( } _ => {} }, - Expr::Not(arg) => extract_non_nullable_columns( + Expr::Not(arg, _) => extract_non_nullable_columns( arg, non_nullable_cols, left_schema, right_schema, false, ), - Expr::IsNotNull(arg) => { + Expr::IsNotNull(arg, _) => { if !top_level { return; } @@ -289,14 +291,16 @@ fn extract_non_nullable_columns( false, ) } - Expr::Cast(Cast { expr, data_type: _ }) - | Expr::TryCast(TryCast { expr, data_type: _ }) => extract_non_nullable_columns( - expr, - non_nullable_cols, - left_schema, - right_schema, - false, - ), + Expr::Cast(Cast { expr, data_type: _ }, _) + | Expr::TryCast(TryCast { expr, data_type: _ }, _) => { + extract_non_nullable_columns( + expr, + non_nullable_cols, + left_schema, + right_schema, + false, + ) + } _ => {} } } diff --git a/datafusion/optimizer/src/extract_equijoin_predicate.rs b/datafusion/optimizer/src/extract_equijoin_predicate.rs index 48191ec206313..16c3355c3b8f6 100644 --- a/datafusion/optimizer/src/extract_equijoin_predicate.rs +++ b/datafusion/optimizer/src/extract_equijoin_predicate.rs @@ -67,16 +67,19 @@ impl OptimizerRule for ExtractEquijoinPredicate { _config: &dyn OptimizerConfig, ) -> Result> { match plan { - LogicalPlan::Join(Join { - left, - right, - mut on, - filter: Some(expr), - join_type, - join_constraint, - schema, - null_equals_null, - }) => { + LogicalPlan::Join( + Join { + left, + right, + mut on, + filter: Some(expr), + join_type, + join_constraint, + schema, + null_equals_null, + }, + _, + ) => { let left_schema = left.schema(); let right_schema = right.schema(); let (equijoin_predicates, non_equijoin_expr) = @@ -84,7 +87,7 @@ impl OptimizerRule for ExtractEquijoinPredicate { if !equijoin_predicates.is_empty() { on.extend(equijoin_predicates); - Ok(Transformed::yes(LogicalPlan::Join(Join { + Ok(Transformed::yes(LogicalPlan::join(Join { left, right, on, @@ -95,7 +98,7 @@ impl OptimizerRule for ExtractEquijoinPredicate { null_equals_null, }))) } else { - Ok(Transformed::no(LogicalPlan::Join(Join { + Ok(Transformed::no(LogicalPlan::join(Join { left, right, on, @@ -123,11 +126,14 @@ fn split_eq_and_noneq_join_predicate( let mut accum_filters: Vec = vec![]; for expr in exprs { match expr { - Expr::BinaryExpr(BinaryExpr { - ref left, - op: Operator::Eq, - ref right, - }) => { + Expr::BinaryExpr( + BinaryExpr { + ref left, + op: Operator::Eq, + ref right, + }, + _, + ) => { let join_key_pair = find_valid_equijoin_key_pair(left, right, left_schema, right_schema)?; diff --git a/datafusion/optimizer/src/filter_null_join_keys.rs b/datafusion/optimizer/src/filter_null_join_keys.rs index 2e7a751ca4c57..3f190ff326673 100644 --- a/datafusion/optimizer/src/filter_null_join_keys.rs +++ b/datafusion/optimizer/src/filter_null_join_keys.rs @@ -50,7 +50,7 @@ impl OptimizerRule for FilterNullJoinKeys { return Ok(Transformed::no(plan)); } match plan { - LogicalPlan::Join(mut join) + LogicalPlan::Join(mut join, _) if !join.on.is_empty() && !join.null_equals_null => { let (left_preserved, right_preserved) = @@ -74,17 +74,17 @@ impl OptimizerRule for FilterNullJoinKeys { if !left_filters.is_empty() { let predicate = create_not_null_predicate(left_filters); - join.left = Arc::new(LogicalPlan::Filter(Filter::try_new( + join.left = Arc::new(LogicalPlan::filter(Filter::try_new( predicate, join.left, )?)); } if !right_filters.is_empty() { let predicate = create_not_null_predicate(right_filters); - join.right = Arc::new(LogicalPlan::Filter(Filter::try_new( + join.right = Arc::new(LogicalPlan::filter(Filter::try_new( predicate, join.right, )?)); } - Ok(Transformed::yes(LogicalPlan::Join(join))) + Ok(Transformed::yes(LogicalPlan::join(join))) } _ => Ok(Transformed::no(plan)), } @@ -95,10 +95,7 @@ impl OptimizerRule for FilterNullJoinKeys { } fn create_not_null_predicate(filters: Vec) -> Expr { - let not_null_exprs: Vec = filters - .into_iter() - .map(|c| Expr::IsNotNull(Box::new(c))) - .collect(); + let not_null_exprs: Vec = filters.into_iter().map(Expr::is_not_null).collect(); // directly unwrap since it should always have a value conjunction(not_null_exprs).unwrap() diff --git a/datafusion/optimizer/src/optimize_projections/mod.rs b/datafusion/optimizer/src/optimize_projections/mod.rs index 1519c54dbf68a..0a6ceeb903fe7 100644 --- a/datafusion/optimizer/src/optimize_projections/mod.rs +++ b/datafusion/optimizer/src/optimize_projections/mod.rs @@ -119,12 +119,12 @@ fn optimize_projections( // Recursively rewrite any nodes that may be able to avoid computation given // their parents' required indices. match plan { - LogicalPlan::Projection(proj) => { + LogicalPlan::Projection(proj, _) => { return merge_consecutive_projections(proj)?.transform_data(|proj| { rewrite_projection_given_requirements(proj, config, &indices) }) } - LogicalPlan::Aggregate(aggregate) => { + LogicalPlan::Aggregate(aggregate, _) => { // Split parent requirements to GROUP BY and aggregate sections: let n_group_exprs = aggregate.group_expr_len()?; // Offset aggregate indices so that they point to valid indices at @@ -200,10 +200,10 @@ fn optimize_projections( new_group_bys, new_aggr_expr, ) - .map(LogicalPlan::Aggregate) + .map(LogicalPlan::aggregate) }); } - LogicalPlan::Window(window) => { + LogicalPlan::Window(window, _) => { let input_schema = Arc::clone(window.input.schema()); // Split parent requirements to child and window expression sections: let n_input_fields = input_schema.fields().len(); @@ -238,12 +238,12 @@ fn optimize_projections( add_projection_on_top_if_helpful(window_child, required_exprs)? .data; Window::try_new(new_window_expr, Arc::new(window_child)) - .map(LogicalPlan::Window) + .map(LogicalPlan::window) .map(Transformed::yes) } }); } - LogicalPlan::TableScan(table_scan) => { + LogicalPlan::TableScan(table_scan, _) => { let TableScan { table_name, source, @@ -266,7 +266,7 @@ fn optimize_projections( filters, fetch, ) - .map(LogicalPlan::TableScan) + .map(LogicalPlan::table_scan) .map(Transformed::yes); } // Other node types are handled below @@ -276,12 +276,12 @@ fn optimize_projections( // For other plan node types, calculate indices for columns they use and // try to rewrite their children let mut child_required_indices: Vec = match &plan { - LogicalPlan::Sort(_) - | LogicalPlan::Filter(_) - | LogicalPlan::Repartition(_) - | LogicalPlan::Union(_) - | LogicalPlan::SubqueryAlias(_) - | LogicalPlan::Distinct(Distinct::On(_)) => { + LogicalPlan::Sort(_, _) + | LogicalPlan::Filter(_, _) + | LogicalPlan::Repartition(_, _) + | LogicalPlan::Union(_, _) + | LogicalPlan::SubqueryAlias(_, _) + | LogicalPlan::Distinct(Distinct::On(_), _) => { // Pass index requirements from the parent as well as column indices // that appear in this plan's expressions to its child. All these // operators benefit from "small" inputs, so the projection_beneficial @@ -296,7 +296,7 @@ fn optimize_projections( }) .collect::>()? } - LogicalPlan::Limit(_) => { + LogicalPlan::Limit(_, _) => { // Pass index requirements from the parent as well as column indices // that appear in this plan's expressions to its child. These operators // do not benefit from "small" inputs, so the projection_beneficial @@ -306,14 +306,14 @@ fn optimize_projections( .map(|input| indices.clone().with_plan_exprs(&plan, input.schema())) .collect::>()? } - LogicalPlan::Copy(_) - | LogicalPlan::Ddl(_) - | LogicalPlan::Dml(_) - | LogicalPlan::Explain(_) - | LogicalPlan::Analyze(_) - | LogicalPlan::Subquery(_) - | LogicalPlan::Statement(_) - | LogicalPlan::Distinct(Distinct::All(_)) => { + LogicalPlan::Copy(_, _) + | LogicalPlan::Ddl(_, _) + | LogicalPlan::Dml(_, _) + | LogicalPlan::Explain(_, _) + | LogicalPlan::Analyze(_, _) + | LogicalPlan::Subquery(_, _) + | LogicalPlan::Statement(_, _) + | LogicalPlan::Distinct(Distinct::All(_), _) => { // These plans require all their fields, and their children should // be treated as final plans -- otherwise, we may have schema a // mismatch. @@ -324,7 +324,7 @@ fn optimize_projections( .map(RequiredIndicies::new_for_all_exprs) .collect() } - LogicalPlan::Extension(extension) => { + LogicalPlan::Extension(extension, _) => { let Some(necessary_children_indices) = extension.node.necessary_children_exprs(indices.indices()) else { @@ -346,14 +346,14 @@ fn optimize_projections( }) .collect::>>()? } - LogicalPlan::EmptyRelation(_) - | LogicalPlan::RecursiveQuery(_) - | LogicalPlan::Values(_) - | LogicalPlan::DescribeTable(_) => { + LogicalPlan::EmptyRelation(_, _) + | LogicalPlan::RecursiveQuery(_, _) + | LogicalPlan::Values(_, _) + | LogicalPlan::DescribeTable(_, _) => { // These operators have no inputs, so stop the optimization process. return Ok(Transformed::no(plan)); } - LogicalPlan::Join(join) => { + LogicalPlan::Join(join, _) => { let left_len = join.left.schema().fields().len(); let (left_req_indices, right_req_indices) = split_join_requirements(left_len, indices, &join.join_type); @@ -369,17 +369,20 @@ fn optimize_projections( ] } // these nodes are explicitly rewritten in the match statement above - LogicalPlan::Projection(_) - | LogicalPlan::Aggregate(_) - | LogicalPlan::Window(_) - | LogicalPlan::TableScan(_) => { + LogicalPlan::Projection(_, _) + | LogicalPlan::Aggregate(_, _) + | LogicalPlan::Window(_, _) + | LogicalPlan::TableScan(_, _) => { return internal_err!( "OptimizeProjection: should have handled in the match statement above" ); } - LogicalPlan::Unnest(Unnest { - dependency_indices, .. - }) => { + LogicalPlan::Unnest( + Unnest { + dependency_indices, .. + }, + _, + ) => { vec![RequiredIndicies::new_from_indices( dependency_indices.clone(), )] @@ -452,7 +455,7 @@ fn merge_consecutive_projections(proj: Projection) -> Result Result Result rewrite_expr(*expr, &prev_projection).map(|result| { - result.update_data(|expr| Expr::Alias(Alias::new(expr, relation, name))) + Expr::Alias( + Alias { + expr, + relation, + name, + }, + _, + ) => rewrite_expr(*expr, &prev_projection).map(|result| { + result.update_data(|expr| expr.alias_qualified(relation, name)) }), e => rewrite_expr(e, &prev_projection), } @@ -513,7 +519,7 @@ fn merge_consecutive_projections(proj: Projection) -> Result Result bool { - matches!(expr, Expr::Column(_) | Expr::Literal(_)) + matches!(expr, Expr::Column(_, _) | Expr::Literal(_, _)) } /// Rewrites a projection expression using the projection before it (i.e. its input) @@ -572,8 +578,8 @@ fn rewrite_expr(expr: Expr, input: &Projection) -> Result> { expr.transform_up(|expr| { match expr { // remove any intermediate aliases - Expr::Alias(alias) => Ok(Transformed::yes(*alias.expr)), - Expr::Column(col) => { + Expr::Alias(alias, _) => Ok(Transformed::yes(*alias.expr)), + Expr::Column(col, _) => { // Find index of column: let idx = input.schema.index_of_column(&col)?; // get the corresponding unaliased input expression @@ -604,16 +610,16 @@ fn outer_columns<'a>(expr: &'a Expr, columns: &mut HashSet<&'a Column>) { // inspect_expr_pre doesn't handle subquery references, so find them explicitly expr.apply(|expr| { match expr { - Expr::OuterReferenceColumn(_, col) => { + Expr::OuterReferenceColumn(_, col, _) => { columns.insert(col); } - Expr::ScalarSubquery(subquery) => { + Expr::ScalarSubquery(subquery, _) => { outer_columns_helper_multi(&subquery.outer_ref_columns, columns); } - Expr::Exists(exists) => { + Expr::Exists(exists, _) => { outer_columns_helper_multi(&exists.subquery.outer_ref_columns, columns); } - Expr::InSubquery(insubquery) => { + Expr::InSubquery(insubquery, _) => { outer_columns_helper_multi( &insubquery.subquery.outer_ref_columns, columns, @@ -721,7 +727,7 @@ fn add_projection_on_top_if_helpful( Ok(Transformed::no(plan)) } else { Projection::try_new(project_exprs, Arc::new(plan)) - .map(LogicalPlan::Projection) + .map(LogicalPlan::projection) .map(Transformed::yes) } } @@ -763,7 +769,7 @@ fn rewrite_projection_given_requirements( Ok(Transformed::yes(input)) } else { Projection::try_new(exprs_used, Arc::new(input)) - .map(LogicalPlan::Projection) + .map(LogicalPlan::projection) .map(Transformed::yes) } }) @@ -1208,7 +1214,7 @@ mod tests { let expr = Box::new(col("a")); let pattern = Box::new(lit("[0-9]")); let similar_to_expr = - Expr::SimilarTo(Like::new(false, expr, pattern, None, false)); + Expr::similar_to(Like::new(false, expr, pattern, None, false)); let plan = LogicalPlanBuilder::from(table_scan) .project(vec![similar_to_expr])? .build()?; @@ -1276,7 +1282,7 @@ mod tests { #[test] fn test_user_defined_logical_plan_node() -> Result<()> { let table_scan = test_table_scan()?; - let custom_plan = LogicalPlan::Extension(Extension { + let custom_plan = LogicalPlan::extension(Extension { node: Arc::new(NoOpUserDefined::new( Arc::clone(table_scan.schema()), Arc::new(table_scan.clone()), @@ -1299,8 +1305,8 @@ mod tests { #[test] fn test_user_defined_logical_plan_node2() -> Result<()> { let table_scan = test_table_scan()?; - let exprs = vec![Expr::Column(Column::from_qualified_name("b"))]; - let custom_plan = LogicalPlan::Extension(Extension { + let exprs = vec![Expr::column(Column::from_qualified_name("b"))]; + let custom_plan = LogicalPlan::extension(Extension { node: Arc::new( NoOpUserDefined::new( Arc::clone(table_scan.schema()), @@ -1327,15 +1333,15 @@ mod tests { #[test] fn test_user_defined_logical_plan_node3() -> Result<()> { let table_scan = test_table_scan()?; - let left_expr = Expr::Column(Column::from_qualified_name("b")); - let right_expr = Expr::Column(Column::from_qualified_name("c")); - let binary_expr = Expr::BinaryExpr(BinaryExpr::new( + let left_expr = Expr::column(Column::from_qualified_name("b")); + let right_expr = Expr::column(Column::from_qualified_name("c")); + let binary_expr = Expr::binary_expr(BinaryExpr::new( Box::new(left_expr), Operator::Plus, Box::new(right_expr), )); let exprs = vec![binary_expr]; - let custom_plan = LogicalPlan::Extension(Extension { + let custom_plan = LogicalPlan::extension(Extension { node: Arc::new( NoOpUserDefined::new( Arc::clone(table_scan.schema()), @@ -1362,7 +1368,7 @@ mod tests { fn test_user_defined_logical_plan_node4() -> Result<()> { let left_table = test_table_scan_with_name("l")?; let right_table = test_table_scan_with_name("r")?; - let custom_plan = LogicalPlan::Extension(Extension { + let custom_plan = LogicalPlan::extension(Extension { node: Arc::new(UserDefinedCrossJoin::new( Arc::new(left_table), Arc::new(right_table), @@ -1691,7 +1697,7 @@ mod tests { let table_scan = test_table_scan()?; let projection = LogicalPlanBuilder::from(table_scan) - .project(vec![Expr::Cast(Cast::new( + .project(vec![Expr::cast(Cast::new( Box::new(col("c")), DataType::Float64, ))])? @@ -1731,7 +1737,7 @@ mod tests { // relation is `None`). PlanBuilder resolves the expressions let expr = vec![col("test.a"), col("test.b")]; let plan = - LogicalPlan::Projection(Projection::try_new(expr, Arc::new(table_scan))?); + LogicalPlan::projection(Projection::try_new(expr, Arc::new(table_scan))?); assert_fields_eq(&plan, vec!["a", "b"]); @@ -1942,7 +1948,7 @@ mod tests { fn test_window() -> Result<()> { let table_scan = test_table_scan()?; - let max1 = Expr::WindowFunction(expr::WindowFunction::new( + let max1 = Expr::window_function(expr::WindowFunction::new( WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![col("test.a")], )) @@ -1950,7 +1956,7 @@ mod tests { .build() .unwrap(); - let max2 = Expr::WindowFunction(expr::WindowFunction::new( + let max2 = Expr::window_function(expr::WindowFunction::new( WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![col("test.b")], )); diff --git a/datafusion/optimizer/src/optimizer.rs b/datafusion/optimizer/src/optimizer.rs index 975150cd61220..084a3eea16174 100644 --- a/datafusion/optimizer/src/optimizer.rs +++ b/datafusion/optimizer/src/optimizer.rs @@ -496,7 +496,7 @@ mod tests { fn skip_failing_rule() { let opt = Optimizer::with_rules(vec![Arc::new(BadRule {})]); let config = OptimizerContext::new().with_skip_failing_rules(true); - let plan = LogicalPlan::EmptyRelation(EmptyRelation { + let plan = LogicalPlan::empty_relation(EmptyRelation { produce_one_row: false, schema: Arc::new(DFSchema::empty()), }); @@ -507,7 +507,7 @@ mod tests { fn no_skip_failing_rule() { let opt = Optimizer::with_rules(vec![Arc::new(BadRule {})]); let config = OptimizerContext::new().with_skip_failing_rules(false); - let plan = LogicalPlan::EmptyRelation(EmptyRelation { + let plan = LogicalPlan::empty_relation(EmptyRelation { produce_one_row: false, schema: Arc::new(DFSchema::empty()), }); @@ -523,7 +523,7 @@ mod tests { fn generate_different_schema() { let opt = Optimizer::with_rules(vec![Arc::new(GetTableScanRule {})]); let config = OptimizerContext::new().with_skip_failing_rules(false); - let plan = LogicalPlan::EmptyRelation(EmptyRelation { + let plan = LogicalPlan::empty_relation(EmptyRelation { produce_one_row: false, schema: Arc::new(DFSchema::empty()), }); @@ -556,7 +556,7 @@ mod tests { fn skip_generate_different_schema() { let opt = Optimizer::with_rules(vec![Arc::new(GetTableScanRule {})]); let config = OptimizerContext::new().with_skip_failing_rules(true); - let plan = LogicalPlan::EmptyRelation(EmptyRelation { + let plan = LogicalPlan::empty_relation(EmptyRelation { produce_one_row: false, schema: Arc::new(DFSchema::empty()), }); @@ -573,7 +573,7 @@ mod tests { let input = Arc::new(test_table_scan()?); let input_schema = Arc::clone(input.schema()); - let plan = LogicalPlan::Projection(Projection::try_new_with_schema( + let plan = LogicalPlan::projection(Projection::try_new_with_schema( vec![col("a"), col("b"), col("c")], input, add_metadata_to_fields(input_schema.as_ref()), @@ -740,7 +740,7 @@ mod tests { _config: &dyn OptimizerConfig, ) -> Result> { let projection = match plan { - LogicalPlan::Projection(p) if p.expr.len() >= 2 => p, + LogicalPlan::Projection(p, _) if p.expr.len() >= 2 => p, _ => return Ok(Transformed::no(plan)), }; @@ -754,7 +754,7 @@ mod tests { exprs.rotate_left(1); } - Ok(Transformed::yes(LogicalPlan::Projection( + Ok(Transformed::yes(LogicalPlan::projection( Projection::try_new(exprs, Arc::clone(&projection.input))?, ))) } diff --git a/datafusion/optimizer/src/plan_signature.rs b/datafusion/optimizer/src/plan_signature.rs index 73e6b418272a9..95f72f5535f69 100644 --- a/datafusion/optimizer/src/plan_signature.rs +++ b/datafusion/optimizer/src/plan_signature.rs @@ -97,21 +97,22 @@ mod tests { fn node_number_for_some_plan() -> Result<()> { let schema = Arc::new(DFSchema::empty()); - let one_node_plan = - Arc::new(LogicalPlan::EmptyRelation(datafusion_expr::EmptyRelation { + let one_node_plan = Arc::new(LogicalPlan::empty_relation( + datafusion_expr::EmptyRelation { produce_one_row: false, schema: Arc::clone(&schema), - })); + }, + )); assert_eq!(1, get_node_number(&one_node_plan).get()); - let two_node_plan = Arc::new(LogicalPlan::Projection( + let two_node_plan = Arc::new(LogicalPlan::projection( datafusion_expr::Projection::try_new(vec![lit(1), lit(2)], one_node_plan)?, )); assert_eq!(2, get_node_number(&two_node_plan).get()); - let five_node_plan = Arc::new(LogicalPlan::Union(datafusion_expr::Union { + let five_node_plan = Arc::new(LogicalPlan::union(datafusion_expr::Union { inputs: vec![Arc::clone(&two_node_plan), two_node_plan], schema, })); diff --git a/datafusion/optimizer/src/propagate_empty_relation.rs b/datafusion/optimizer/src/propagate_empty_relation.rs index d26df073dc6fd..8f38970d1dc4f 100644 --- a/datafusion/optimizer/src/propagate_empty_relation.rs +++ b/datafusion/optimizer/src/propagate_empty_relation.rs @@ -58,21 +58,21 @@ impl OptimizerRule for PropagateEmptyRelation { _config: &dyn OptimizerConfig, ) -> Result> { match plan { - LogicalPlan::EmptyRelation(_) => Ok(Transformed::no(plan)), - LogicalPlan::Projection(_) - | LogicalPlan::Filter(_) - | LogicalPlan::Window(_) - | LogicalPlan::Sort(_) - | LogicalPlan::SubqueryAlias(_) - | LogicalPlan::Repartition(_) - | LogicalPlan::Limit(_) => { + LogicalPlan::EmptyRelation(_, _) => Ok(Transformed::no(plan)), + LogicalPlan::Projection(_, _) + | LogicalPlan::Filter(_, _) + | LogicalPlan::Window(_, _) + | LogicalPlan::Sort(_, _) + | LogicalPlan::SubqueryAlias(_, _) + | LogicalPlan::Repartition(_, _) + | LogicalPlan::Limit(_, _) => { let empty = empty_child(&plan)?; if let Some(empty_plan) = empty { return Ok(Transformed::yes(empty_plan)); } Ok(Transformed::no(plan)) } - LogicalPlan::Join(ref join) => { + LogicalPlan::Join(ref join, _) => { // TODO: For Join, more join type need to be careful: // For LeftOut/Full Join, if the right side is empty, the Join can be eliminated with a Projection with left side // columns + right side columns replaced with null values. @@ -83,43 +83,43 @@ impl OptimizerRule for PropagateEmptyRelation { match join.join_type { // For Full Join, only both sides are empty, the Join result is empty. JoinType::Full if left_empty && right_empty => Ok(Transformed::yes( - LogicalPlan::EmptyRelation(EmptyRelation { + LogicalPlan::empty_relation(EmptyRelation { produce_one_row: false, schema: Arc::clone(&join.schema), }), )), JoinType::Inner if left_empty || right_empty => Ok(Transformed::yes( - LogicalPlan::EmptyRelation(EmptyRelation { + LogicalPlan::empty_relation(EmptyRelation { produce_one_row: false, schema: Arc::clone(&join.schema), }), )), JoinType::Left if left_empty => Ok(Transformed::yes( - LogicalPlan::EmptyRelation(EmptyRelation { + LogicalPlan::empty_relation(EmptyRelation { produce_one_row: false, schema: Arc::clone(&join.schema), }), )), JoinType::Right if right_empty => Ok(Transformed::yes( - LogicalPlan::EmptyRelation(EmptyRelation { + LogicalPlan::empty_relation(EmptyRelation { produce_one_row: false, schema: Arc::clone(&join.schema), }), )), JoinType::LeftSemi if left_empty || right_empty => Ok( - Transformed::yes(LogicalPlan::EmptyRelation(EmptyRelation { + Transformed::yes(LogicalPlan::empty_relation(EmptyRelation { produce_one_row: false, schema: Arc::clone(&join.schema), })), ), JoinType::RightSemi if left_empty || right_empty => Ok( - Transformed::yes(LogicalPlan::EmptyRelation(EmptyRelation { + Transformed::yes(LogicalPlan::empty_relation(EmptyRelation { produce_one_row: false, schema: Arc::clone(&join.schema), })), ), JoinType::LeftAnti if left_empty => Ok(Transformed::yes( - LogicalPlan::EmptyRelation(EmptyRelation { + LogicalPlan::empty_relation(EmptyRelation { produce_one_row: false, schema: Arc::clone(&join.schema), }), @@ -131,7 +131,7 @@ impl OptimizerRule for PropagateEmptyRelation { Ok(Transformed::yes((*join.right).clone())) } JoinType::RightAnti if right_empty => Ok(Transformed::yes( - LogicalPlan::EmptyRelation(EmptyRelation { + LogicalPlan::empty_relation(EmptyRelation { produce_one_row: false, schema: Arc::clone(&join.schema), }), @@ -139,20 +139,20 @@ impl OptimizerRule for PropagateEmptyRelation { _ => Ok(Transformed::no(plan)), } } - LogicalPlan::Aggregate(ref agg) => { + LogicalPlan::Aggregate(ref agg, _) => { if !agg.group_expr.is_empty() { if let Some(empty_plan) = empty_child(&plan)? { return Ok(Transformed::yes(empty_plan)); } } - Ok(Transformed::no(LogicalPlan::Aggregate(agg.clone()))) + Ok(Transformed::no(LogicalPlan::aggregate(agg.clone()))) } - LogicalPlan::Union(ref union) => { + LogicalPlan::Union(ref union, _) => { let new_inputs = union .inputs .iter() .filter(|input| match &***input { - LogicalPlan::EmptyRelation(empty) => empty.produce_one_row, + LogicalPlan::EmptyRelation(empty, _) => empty.produce_one_row, _ => true, }) .cloned() @@ -161,7 +161,7 @@ impl OptimizerRule for PropagateEmptyRelation { if new_inputs.len() == union.inputs.len() { Ok(Transformed::no(plan)) } else if new_inputs.is_empty() { - Ok(Transformed::yes(LogicalPlan::EmptyRelation( + Ok(Transformed::yes(LogicalPlan::empty_relation( EmptyRelation { produce_one_row: false, schema: Arc::clone(plan.schema()), @@ -174,7 +174,7 @@ impl OptimizerRule for PropagateEmptyRelation { if child.schema().eq(plan.schema()) { Ok(Transformed::yes(child)) } else { - Ok(Transformed::yes(LogicalPlan::Projection( + Ok(Transformed::yes(LogicalPlan::projection( Projection::new_from_schema( Arc::new(child), Arc::clone(plan.schema()), @@ -182,7 +182,7 @@ impl OptimizerRule for PropagateEmptyRelation { ))) } } else { - Ok(Transformed::yes(LogicalPlan::Union(Union { + Ok(Transformed::yes(LogicalPlan::union(Union { inputs: new_inputs, schema: Arc::clone(&union.schema), }))) @@ -198,11 +198,11 @@ fn binary_plan_children_is_empty(plan: &LogicalPlan) -> Result<(bool, bool)> { match plan.inputs()[..] { [left, right] => { let left_empty = match left { - LogicalPlan::EmptyRelation(empty) => !empty.produce_one_row, + LogicalPlan::EmptyRelation(empty, _) => !empty.produce_one_row, _ => false, }; let right_empty = match right { - LogicalPlan::EmptyRelation(empty) => !empty.produce_one_row, + LogicalPlan::EmptyRelation(empty, _) => !empty.produce_one_row, _ => false, }; Ok((left_empty, right_empty)) @@ -214,9 +214,9 @@ fn binary_plan_children_is_empty(plan: &LogicalPlan) -> Result<(bool, bool)> { fn empty_child(plan: &LogicalPlan) -> Result> { match plan.inputs()[..] { [child] => match child { - LogicalPlan::EmptyRelation(empty) => { + LogicalPlan::EmptyRelation(empty, _) => { if !empty.produce_one_row { - Ok(Some(LogicalPlan::EmptyRelation(EmptyRelation { + Ok(Some(LogicalPlan::empty_relation(EmptyRelation { produce_one_row: false, schema: Arc::clone(plan.schema()), }))) @@ -564,7 +564,7 @@ mod tests { let fields = test_table_scan_fields(); - let empty = LogicalPlan::EmptyRelation(EmptyRelation { + let empty = LogicalPlan::empty_relation(EmptyRelation { produce_one_row: false, schema: Arc::new(DFSchema::from_unqualified_fields( fields.into(), diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index 195dc06578b2b..905ec2cc61460 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -252,42 +252,42 @@ fn schema_columns(schema: &DFSchema) -> HashSet { fn can_evaluate_as_join_condition(predicate: &Expr) -> Result { let mut is_evaluate = true; predicate.apply(|expr| match expr { - Expr::Column(_) - | Expr::Literal(_) - | Expr::Placeholder(_) - | Expr::ScalarVariable(_, _) => Ok(TreeNodeRecursion::Jump), + Expr::Column(_, _) + | Expr::Literal(_, _) + | Expr::Placeholder(_, _) + | Expr::ScalarVariable(_, _, _) => Ok(TreeNodeRecursion::Jump), Expr::Exists { .. } - | Expr::InSubquery(_) - | Expr::ScalarSubquery(_) - | Expr::OuterReferenceColumn(_, _) - | Expr::Unnest(_) => { + | Expr::InSubquery(_, _) + | Expr::ScalarSubquery(_, _) + | Expr::OuterReferenceColumn(_, _, _) + | Expr::Unnest(_, _) => { is_evaluate = false; Ok(TreeNodeRecursion::Stop) } - Expr::Alias(_) - | Expr::BinaryExpr(_) - | Expr::Like(_) - | Expr::SimilarTo(_) - | Expr::Not(_) - | Expr::IsNotNull(_) - | Expr::IsNull(_) - | Expr::IsTrue(_) - | Expr::IsFalse(_) - | Expr::IsUnknown(_) - | Expr::IsNotTrue(_) - | Expr::IsNotFalse(_) - | Expr::IsNotUnknown(_) - | Expr::Negative(_) - | Expr::Between(_) - | Expr::Case(_) - | Expr::Cast(_) - | Expr::TryCast(_) + Expr::Alias(_, _) + | Expr::BinaryExpr(_, _) + | Expr::Like(_, _) + | Expr::SimilarTo(_, _) + | Expr::Not(_, _) + | Expr::IsNotNull(_, _) + | Expr::IsNull(_, _) + | Expr::IsTrue(_, _) + | Expr::IsFalse(_, _) + | Expr::IsUnknown(_, _) + | Expr::IsNotTrue(_, _) + | Expr::IsNotFalse(_, _) + | Expr::IsNotUnknown(_, _) + | Expr::Negative(_, _) + | Expr::Between(_, _) + | Expr::Case(_, _) + | Expr::Cast(_, _) + | Expr::TryCast(_, _) | Expr::InList { .. } - | Expr::ScalarFunction(_) => Ok(TreeNodeRecursion::Continue), - Expr::AggregateFunction(_) - | Expr::WindowFunction(_) + | Expr::ScalarFunction(_, _) => Ok(TreeNodeRecursion::Continue), + Expr::AggregateFunction(_, _) + | Expr::WindowFunction(_, _) | Expr::Wildcard { .. } - | Expr::GroupingSet(_) => internal_err!("Unsupported predicate type"), + | Expr::GroupingSet(_, _) => internal_err!("Unsupported predicate type"), })?; Ok(is_evaluate) } @@ -333,11 +333,14 @@ fn extract_or_clauses_for_join<'a>( // new formed OR clauses and their column references filters.iter().filter_map(move |expr| { - if let Expr::BinaryExpr(BinaryExpr { - left, - op: Operator::Or, - right, - }) = expr + if let Expr::BinaryExpr( + BinaryExpr { + left, + op: Operator::Or, + right, + }, + _, + ) = expr { let left_expr = extract_or_clause(left.as_ref(), &schema_columns); let right_expr = extract_or_clause(right.as_ref(), &schema_columns); @@ -366,11 +369,14 @@ fn extract_or_clause(expr: &Expr, schema_columns: &HashSet) -> Option { + Expr::BinaryExpr( + BinaryExpr { + left: l_expr, + op: Operator::Or, + right: r_expr, + }, + _, + ) => { let l_expr = extract_or_clause(l_expr, schema_columns); let r_expr = extract_or_clause(r_expr, schema_columns); @@ -378,11 +384,14 @@ fn extract_or_clause(expr: &Expr, schema_columns: &HashSet) -> Option { + Expr::BinaryExpr( + BinaryExpr { + left: l_expr, + op: Operator::And, + right: r_expr, + }, + _, + ) => { let l_expr = extract_or_clause(l_expr, schema_columns); let r_expr = extract_or_clause(r_expr, schema_columns); @@ -498,11 +507,11 @@ fn push_down_all_join( } if let Some(predicate) = conjunction(left_push) { - join.left = Arc::new(LogicalPlan::Filter(Filter::try_new(predicate, join.left)?)); + join.left = Arc::new(LogicalPlan::filter(Filter::try_new(predicate, join.left)?)); } if let Some(predicate) = conjunction(right_push) { join.right = - Arc::new(LogicalPlan::Filter(Filter::try_new(predicate, join.right)?)); + Arc::new(LogicalPlan::filter(Filter::try_new(predicate, join.right)?)); } // Add any new join conditions as the non join predicates @@ -510,9 +519,9 @@ fn push_down_all_join( join.filter = conjunction(join_conditions); // wrap the join on the filter whose predicates must be kept, if any - let plan = LogicalPlan::Join(join); + let plan = LogicalPlan::join(join); let plan = if let Some(predicate) = conjunction(keep_predicates) { - LogicalPlan::Filter(Filter::try_new(predicate, Arc::new(plan))?) + LogicalPlan::filter(Filter::try_new(predicate, Arc::new(plan))?) } else { plan }; @@ -541,7 +550,7 @@ fn push_down_join( && predicates.is_empty() && inferred_join_predicates.is_empty() { - return Ok(Transformed::no(LogicalPlan::Join(join))); + return Ok(Transformed::no(LogicalPlan::join(join))); } push_down_all_join(predicates, inferred_join_predicates, join, on_filters) @@ -765,18 +774,18 @@ impl OptimizerRule for PushDownFilter { plan: LogicalPlan, _config: &dyn OptimizerConfig, ) -> Result> { - if let LogicalPlan::Join(join) = plan { + if let LogicalPlan::Join(join, _) = plan { return push_down_join(join, None); }; let plan_schema = Arc::clone(plan.schema()); - let LogicalPlan::Filter(mut filter) = plan else { + let LogicalPlan::Filter(mut filter, _) = plan else { return Ok(Transformed::no(plan)); }; match Arc::unwrap_or_clone(filter.input) { - LogicalPlan::Filter(child_filter) => { + LogicalPlan::Filter(child_filter, _) => { let parents_predicates = split_conjunction_owned(filter.predicate); // remove duplicated filters @@ -792,31 +801,31 @@ impl OptimizerRule for PushDownFilter { let Some(new_predicate) = conjunction(new_predicates) else { return plan_err!("at least one expression exists"); }; - let new_filter = LogicalPlan::Filter(Filter::try_new( + let new_filter = LogicalPlan::filter(Filter::try_new( new_predicate, child_filter.input, )?); self.rewrite(new_filter, _config) } - LogicalPlan::Repartition(repartition) => { + LogicalPlan::Repartition(repartition, _) => { let new_filter = Filter::try_new(filter.predicate, Arc::clone(&repartition.input)) - .map(LogicalPlan::Filter)?; - insert_below(LogicalPlan::Repartition(repartition), new_filter) + .map(LogicalPlan::filter)?; + insert_below(LogicalPlan::repartition(repartition), new_filter) } - LogicalPlan::Distinct(distinct) => { + LogicalPlan::Distinct(distinct, _) => { let new_filter = Filter::try_new(filter.predicate, Arc::clone(distinct.input())) - .map(LogicalPlan::Filter)?; - insert_below(LogicalPlan::Distinct(distinct), new_filter) + .map(LogicalPlan::filter)?; + insert_below(LogicalPlan::distinct(distinct), new_filter) } - LogicalPlan::Sort(sort) => { + LogicalPlan::Sort(sort, _) => { let new_filter = Filter::try_new(filter.predicate, Arc::clone(&sort.input)) - .map(LogicalPlan::Filter)?; - insert_below(LogicalPlan::Sort(sort), new_filter) + .map(LogicalPlan::filter)?; + insert_below(LogicalPlan::sort(sort), new_filter) } - LogicalPlan::SubqueryAlias(subquery_alias) => { + LogicalPlan::SubqueryAlias(subquery_alias, _) => { let mut replace_map = HashMap::new(); for (i, (qualifier, field)) in subquery_alias.input.schema().iter().enumerate() @@ -825,18 +834,18 @@ impl OptimizerRule for PushDownFilter { subquery_alias.schema.qualified_field(i); replace_map.insert( qualified_name(sub_qualifier, sub_field.name()), - Expr::Column(Column::new(qualifier.cloned(), field.name())), + Expr::column(Column::new(qualifier.cloned(), field.name())), ); } let new_predicate = replace_cols_by_name(filter.predicate, &replace_map)?; - let new_filter = LogicalPlan::Filter(Filter::try_new( + let new_filter = LogicalPlan::filter(Filter::try_new( new_predicate, Arc::clone(&subquery_alias.input), )?); - insert_below(LogicalPlan::SubqueryAlias(subquery_alias), new_filter) + insert_below(LogicalPlan::subquery_alias(subquery_alias), new_filter) } - LogicalPlan::Projection(projection) => { + LogicalPlan::Projection(projection, _) => { let predicates = split_conjunction_owned(filter.predicate.clone()); let (new_projection, keep_predicate) = rewrite_projection(predicates, projection)?; @@ -845,15 +854,15 @@ impl OptimizerRule for PushDownFilter { None => Ok(new_projection), Some(keep_predicate) => new_projection.map_data(|child_plan| { Filter::try_new(keep_predicate, Arc::new(child_plan)) - .map(LogicalPlan::Filter) + .map(LogicalPlan::filter) }), } } else { filter.input = Arc::new(new_projection.data); - Ok(Transformed::no(LogicalPlan::Filter(filter))) + Ok(Transformed::no(LogicalPlan::filter(filter))) } } - LogicalPlan::Unnest(mut unnest) => { + LogicalPlan::Unnest(mut unnest, _) => { let predicates = split_conjunction_owned(filter.predicate.clone()); let mut non_unnest_predicates = vec![]; let mut unnest_predicates = vec![]; @@ -874,8 +883,8 @@ impl OptimizerRule for PushDownFilter { // Unnest predicates should not be pushed down. // If no non-unnest predicates exist, early return if non_unnest_predicates.is_empty() { - filter.input = Arc::new(LogicalPlan::Unnest(unnest)); - return Ok(Transformed::no(LogicalPlan::Filter(filter))); + filter.input = Arc::new(LogicalPlan::unnest(unnest)); + return Ok(Transformed::no(LogicalPlan::filter(filter))); } // Push down non-unnest filter predicate @@ -888,7 +897,7 @@ impl OptimizerRule for PushDownFilter { let unnest_input = std::mem::take(&mut unnest.input); - let filter_with_unnest_input = LogicalPlan::Filter(Filter::try_new( + let filter_with_unnest_input = LogicalPlan::filter(Filter::try_new( conjunction(non_unnest_predicates).unwrap(), // Safe to unwrap since non_unnest_predicates is not empty. unnest_input, )?); @@ -897,16 +906,16 @@ impl OptimizerRule for PushDownFilter { // The new filter plan will go through another rewrite pass since the rule itself // is applied recursively to all the child from top to down let unnest_plan = - insert_below(LogicalPlan::Unnest(unnest), filter_with_unnest_input)?; + insert_below(LogicalPlan::unnest(unnest), filter_with_unnest_input)?; match conjunction(unnest_predicates) { None => Ok(unnest_plan), - Some(predicate) => Ok(Transformed::yes(LogicalPlan::Filter( + Some(predicate) => Ok(Transformed::yes(LogicalPlan::filter( Filter::try_new(predicate, Arc::new(unnest_plan.data))?, ))), } } - LogicalPlan::Union(ref union) => { + LogicalPlan::Union(ref union, _) => { let mut inputs = Vec::with_capacity(union.inputs.len()); for input in &union.inputs { let mut replace_map = HashMap::new(); @@ -915,23 +924,23 @@ impl OptimizerRule for PushDownFilter { union.schema.qualified_field(i); replace_map.insert( qualified_name(union_qualifier, union_field.name()), - Expr::Column(Column::new(qualifier.cloned(), field.name())), + Expr::column(Column::new(qualifier.cloned(), field.name())), ); } let push_predicate = replace_cols_by_name(filter.predicate.clone(), &replace_map)?; - inputs.push(Arc::new(LogicalPlan::Filter(Filter::try_new( + inputs.push(Arc::new(LogicalPlan::filter(Filter::try_new( push_predicate, Arc::clone(input), )?))) } - Ok(Transformed::yes(LogicalPlan::Union(Union { + Ok(Transformed::yes(LogicalPlan::union(Union { inputs, schema: Arc::clone(&plan_schema), }))) } - LogicalPlan::Aggregate(agg) => { + LogicalPlan::Aggregate(agg, _) => { // We can push down Predicate which in groupby_expr. let group_expr_columns = agg .group_expr @@ -965,7 +974,7 @@ impl OptimizerRule for PushDownFilter { .collect::>>()?; let agg_input = Arc::clone(&agg.input); - Transformed::yes(LogicalPlan::Aggregate(agg)) + Transformed::yes(LogicalPlan::aggregate(agg)) .transform_data(|new_plan| { // If we have a filter to push, we push it down to the input of the aggregate if let Some(predicate) = conjunction(replaced_push_predicates) { @@ -985,8 +994,8 @@ impl OptimizerRule for PushDownFilter { } }) } - LogicalPlan::Join(join) => push_down_join(join, Some(&filter.predicate)), - LogicalPlan::TableScan(scan) => { + LogicalPlan::Join(join, _) => push_down_join(join, Some(&filter.predicate)), + LogicalPlan::TableScan(scan, _) => { let filter_predicates = split_conjunction(&filter.predicate); let (volatile_filters, non_volatile_filters): (Vec<&Expr>, Vec<&Expr>) = @@ -1030,7 +1039,7 @@ impl OptimizerRule for PushDownFilter { .cloned() .collect(); - let new_scan = LogicalPlan::TableScan(TableScan { + let new_scan = LogicalPlan::table_scan(TableScan { filters: new_scan_filters, ..scan }); @@ -1043,7 +1052,7 @@ impl OptimizerRule for PushDownFilter { } }) } - LogicalPlan::Extension(extension_plan) => { + LogicalPlan::Extension(extension_plan, _) => { let prevent_cols = extension_plan.node.prevent_predicate_push_down_columns(); @@ -1064,8 +1073,8 @@ impl OptimizerRule for PushDownFilter { // all predicates are kept, no changes needed if predicate_push_or_keep.iter().all(|&x| !x) { - filter.input = Arc::new(LogicalPlan::Extension(extension_plan)); - return Ok(Transformed::no(LogicalPlan::Filter(filter))); + filter.input = Arc::new(LogicalPlan::extension(extension_plan)); + return Ok(Transformed::no(LogicalPlan::filter(filter))); } // going to push some predicates down, so split the predicates @@ -1088,7 +1097,7 @@ impl OptimizerRule for PushDownFilter { .inputs() .into_iter() .map(|child| { - Ok(LogicalPlan::Filter(Filter::try_new( + Ok(LogicalPlan::filter(Filter::try_new( predicate.clone(), Arc::new(child.clone()), )?)) @@ -1097,12 +1106,12 @@ impl OptimizerRule for PushDownFilter { None => extension_plan.node.inputs().into_iter().cloned().collect(), }; // extension with new inputs. - let child_plan = LogicalPlan::Extension(extension_plan); + let child_plan = LogicalPlan::extension(extension_plan); let new_extension = child_plan.with_new_exprs(child_plan.expressions(), new_children)?; let new_plan = match conjunction(keep_predicates) { - Some(predicate) => LogicalPlan::Filter(Filter::try_new( + Some(predicate) => LogicalPlan::filter(Filter::try_new( predicate, Arc::new(new_extension), )?), @@ -1112,7 +1121,7 @@ impl OptimizerRule for PushDownFilter { } child => { filter.input = Arc::new(child); - Ok(Transformed::no(LogicalPlan::Filter(filter))) + Ok(Transformed::no(LogicalPlan::filter(filter))) } } } @@ -1178,7 +1187,7 @@ fn rewrite_projection( Some(expr) => { // re-write all filters based on this projection // E.g. in `Filter: b\n Projection: a > 1 as b`, we can swap them, but the filter must be "a > 1" - let new_filter = LogicalPlan::Filter(Filter::try_new( + let new_filter = LogicalPlan::filter(Filter::try_new( replace_cols_by_name(expr, &non_volatile_map)?, std::mem::take(&mut projection.input), )?); @@ -1186,17 +1195,17 @@ fn rewrite_projection( projection.input = Arc::new(new_filter); Ok(( - Transformed::yes(LogicalPlan::Projection(projection)), + Transformed::yes(LogicalPlan::projection(projection)), conjunction(keep_predicates), )) } - None => Ok((Transformed::no(LogicalPlan::Projection(projection)), None)), + None => Ok((Transformed::no(LogicalPlan::projection(projection)), None)), } } /// Creates a new LogicalPlan::Filter node. pub fn make_filter(predicate: Expr, input: Arc) -> Result { - Filter::try_new(predicate, input).map(LogicalPlan::Filter) + Filter::try_new(predicate, input).map(LogicalPlan::filter) } /// Replace the existing child of the single input node with `new_child`. @@ -1247,7 +1256,7 @@ pub fn replace_cols_by_name( replace_map: &HashMap, ) -> Result { e.transform_up(|expr| { - Ok(if let Expr::Column(c) = &expr { + Ok(if let Expr::Column(c, _) = &expr { match replace_map.get(&c.flat_name()) { Some(new_c) => Transformed::yes(new_c.clone()), None => Transformed::no(expr), @@ -1263,7 +1272,7 @@ pub fn replace_cols_by_name( fn contain(e: &Expr, check_map: &HashMap) -> bool { let mut is_contain = false; e.apply(|expr| { - Ok(if let Expr::Column(c) = &expr { + Ok(if let Expr::Column(c, _) = &expr { match check_map.get(&c.flat_name()) { Some(_) => { is_contain = true; @@ -1458,7 +1467,7 @@ mod tests { } fn add(left: Expr, right: Expr) -> Expr { - Expr::BinaryExpr(BinaryExpr::new( + Expr::binary_expr(BinaryExpr::new( Box::new(left), Operator::Plus, Box::new(right), @@ -1466,7 +1475,7 @@ mod tests { } fn multiply(left: Expr, right: Expr) -> Expr { - Expr::BinaryExpr(BinaryExpr::new( + Expr::binary_expr(BinaryExpr::new( Box::new(left), Operator::Multiply, Box::new(right), @@ -1594,7 +1603,7 @@ mod tests { fn user_defined_plan() -> Result<()> { let table_scan = test_table_scan()?; - let custom_plan = LogicalPlan::Extension(Extension { + let custom_plan = LogicalPlan::extension(Extension { node: Arc::new(NoopPlan { input: vec![table_scan.clone()], schema: Arc::clone(table_scan.schema()), @@ -1610,7 +1619,7 @@ mod tests { \n TableScan: test, full_filters=[test.a = Int64(1)]"; assert_optimized_plan_eq(plan, expected)?; - let custom_plan = LogicalPlan::Extension(Extension { + let custom_plan = LogicalPlan::extension(Extension { node: Arc::new(NoopPlan { input: vec![table_scan.clone()], schema: Arc::clone(table_scan.schema()), @@ -1627,7 +1636,7 @@ mod tests { \n TableScan: test, full_filters=[test.a = Int64(1)]"; assert_optimized_plan_eq(plan, expected)?; - let custom_plan = LogicalPlan::Extension(Extension { + let custom_plan = LogicalPlan::extension(Extension { node: Arc::new(NoopPlan { input: vec![table_scan.clone(), table_scan.clone()], schema: Arc::clone(table_scan.schema()), @@ -1644,7 +1653,7 @@ mod tests { \n TableScan: test, full_filters=[test.a = Int64(1)]"; assert_optimized_plan_eq(plan, expected)?; - let custom_plan = LogicalPlan::Extension(Extension { + let custom_plan = LogicalPlan::extension(Extension { node: Arc::new(NoopPlan { input: vec![table_scan.clone(), table_scan.clone()], schema: Arc::clone(table_scan.schema()), @@ -2536,7 +2545,7 @@ mod tests { ) -> Result { let test_provider = PushDownProvider { filter_support }; - let table_scan = LogicalPlan::TableScan(TableScan { + let table_scan = LogicalPlan::table_scan(TableScan { table_name: "test".into(), filters, projected_schema: Arc::new(DFSchema::try_from( @@ -3313,7 +3322,7 @@ Projection: a, b let fun = ScalarUDF::new_from_impl(TestScalarUDF { signature: Signature::exact(vec![], Volatility::Volatile), }); - let expr = Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![])); + let expr = Expr::scalar_function(ScalarFunction::new_udf(Arc::new(fun), vec![])); let plan = LogicalPlanBuilder::from(table_scan) .aggregate(vec![col("a")], vec![sum(col("b"))])? @@ -3347,7 +3356,7 @@ Projection: a, b let fun = ScalarUDF::new_from_impl(TestScalarUDF { signature: Signature::exact(vec![], Volatility::Volatile), }); - let expr = Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![])); + let expr = Expr::scalar_function(ScalarFunction::new_udf(Arc::new(fun), vec![])); let left = LogicalPlanBuilder::from(table_scan).build()?; let right_table_scan = test_table_scan_with_name("test2")?; let right = LogicalPlanBuilder::from(right_table_scan).build()?; @@ -3393,7 +3402,7 @@ Projection: a, b let fun = ScalarUDF::new_from_impl(TestScalarUDF { signature: Signature::exact(vec![], Volatility::Volatile), }); - let expr = Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![])); + let expr = Expr::scalar_function(ScalarFunction::new_udf(Arc::new(fun), vec![])); let plan = LogicalPlanBuilder::from(table_scan) .project(vec![col("a"), col("b")])? .filter(expr.gt(lit(0.1)))? @@ -3417,7 +3426,7 @@ Projection: a, b let fun = ScalarUDF::new_from_impl(TestScalarUDF { signature: Signature::exact(vec![], Volatility::Volatile), }); - let expr = Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![])); + let expr = Expr::scalar_function(ScalarFunction::new_udf(Arc::new(fun), vec![])); let plan = LogicalPlanBuilder::from(table_scan) .project(vec![col("a"), col("b")])? .filter( @@ -3444,7 +3453,7 @@ Projection: a, b let fun = ScalarUDF::new_from_impl(TestScalarUDF { signature: Signature::exact(vec![], Volatility::Volatile), }); - let expr = Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![])); + let expr = Expr::scalar_function(ScalarFunction::new_udf(Arc::new(fun), vec![])); let plan = table_scan_with_pushdown_provider_builder( TableProviderFilterPushDown::Unsupported, vec![], diff --git a/datafusion/optimizer/src/push_down_limit.rs b/datafusion/optimizer/src/push_down_limit.rs index 8a3aa4bb84599..ed526e950eb28 100644 --- a/datafusion/optimizer/src/push_down_limit.rs +++ b/datafusion/optimizer/src/push_down_limit.rs @@ -53,29 +53,29 @@ impl OptimizerRule for PushDownLimit { plan: LogicalPlan, _config: &dyn OptimizerConfig, ) -> Result> { - let LogicalPlan::Limit(mut limit) = plan else { + let LogicalPlan::Limit(mut limit, _) = plan else { return Ok(Transformed::no(plan)); }; // Currently only rewrite if skip and fetch are both literals let SkipType::Literal(skip) = limit.get_skip_type()? else { - return Ok(Transformed::no(LogicalPlan::Limit(limit))); + return Ok(Transformed::no(LogicalPlan::limit(limit))); }; let FetchType::Literal(fetch) = limit.get_fetch_type()? else { - return Ok(Transformed::no(LogicalPlan::Limit(limit))); + return Ok(Transformed::no(LogicalPlan::limit(limit))); }; // Merge the Parent Limit and the Child Limit. - if let LogicalPlan::Limit(child) = limit.input.as_ref() { + if let LogicalPlan::Limit(child, _) = limit.input.as_ref() { let SkipType::Literal(child_skip) = child.get_skip_type()? else { - return Ok(Transformed::no(LogicalPlan::Limit(limit))); + return Ok(Transformed::no(LogicalPlan::limit(limit))); }; let FetchType::Literal(child_fetch) = child.get_fetch_type()? else { - return Ok(Transformed::no(LogicalPlan::Limit(limit))); + return Ok(Transformed::no(LogicalPlan::limit(limit))); }; let (skip, fetch) = combine_limit(skip, fetch, child_skip, child_fetch); - let plan = LogicalPlan::Limit(Limit { + let plan = LogicalPlan::limit(Limit { skip: Some(Box::new(lit(skip as i64))), fetch: fetch.map(|f| Box::new(lit(f as i64))), input: Arc::clone(&child.input), @@ -87,74 +87,76 @@ impl OptimizerRule for PushDownLimit { // no fetch to push, so return the original plan let Some(fetch) = fetch else { - return Ok(Transformed::no(LogicalPlan::Limit(limit))); + return Ok(Transformed::no(LogicalPlan::limit(limit))); }; match Arc::unwrap_or_clone(limit.input) { - LogicalPlan::TableScan(mut scan) => { + LogicalPlan::TableScan(mut scan, _) => { let rows_needed = if fetch != 0 { fetch + skip } else { 0 }; let new_fetch = scan .fetch .map(|x| min(x, rows_needed)) .or(Some(rows_needed)); if new_fetch == scan.fetch { - original_limit(skip, fetch, LogicalPlan::TableScan(scan)) + original_limit(skip, fetch, LogicalPlan::table_scan(scan)) } else { // push limit into the table scan itself scan.fetch = scan .fetch .map(|x| min(x, rows_needed)) .or(Some(rows_needed)); - transformed_limit(skip, fetch, LogicalPlan::TableScan(scan)) + transformed_limit(skip, fetch, LogicalPlan::table_scan(scan)) } } - LogicalPlan::Union(mut union) => { + LogicalPlan::Union(mut union, _) => { // push limits to each input of the union union.inputs = union .inputs .into_iter() .map(|input| make_arc_limit(0, fetch + skip, input)) .collect(); - transformed_limit(skip, fetch, LogicalPlan::Union(union)) + transformed_limit(skip, fetch, LogicalPlan::union(union)) } - LogicalPlan::Join(join) => Ok(push_down_join(join, fetch + skip) + LogicalPlan::Join(join, _) => Ok(push_down_join(join, fetch + skip) .update_data(|join| { - make_limit(skip, fetch, Arc::new(LogicalPlan::Join(join))) + make_limit(skip, fetch, Arc::new(LogicalPlan::join(join))) })), - LogicalPlan::Sort(mut sort) => { + LogicalPlan::Sort(mut sort, _) => { let new_fetch = { let sort_fetch = skip + fetch; Some(sort.fetch.map(|f| f.min(sort_fetch)).unwrap_or(sort_fetch)) }; if new_fetch == sort.fetch { if skip > 0 { - original_limit(skip, fetch, LogicalPlan::Sort(sort)) + original_limit(skip, fetch, LogicalPlan::sort(sort)) } else { - Ok(Transformed::yes(LogicalPlan::Sort(sort))) + Ok(Transformed::yes(LogicalPlan::sort(sort))) } } else { sort.fetch = new_fetch; - limit.input = Arc::new(LogicalPlan::Sort(sort)); - Ok(Transformed::yes(LogicalPlan::Limit(limit))) + limit.input = Arc::new(LogicalPlan::sort(sort)); + Ok(Transformed::yes(LogicalPlan::limit(limit))) } } - LogicalPlan::Projection(mut proj) => { + LogicalPlan::Projection(mut proj, _) => { // commute limit.input = Arc::clone(&proj.input); - let new_limit = LogicalPlan::Limit(limit); + let new_limit = LogicalPlan::limit(limit); proj.input = Arc::new(new_limit); - Ok(Transformed::yes(LogicalPlan::Projection(proj))) + Ok(Transformed::yes(LogicalPlan::projection(proj))) } - LogicalPlan::SubqueryAlias(mut subquery_alias) => { + LogicalPlan::SubqueryAlias(mut subquery_alias, _) => { // commute limit.input = Arc::clone(&subquery_alias.input); - let new_limit = LogicalPlan::Limit(limit); + let new_limit = LogicalPlan::limit(limit); subquery_alias.input = Arc::new(new_limit); - Ok(Transformed::yes(LogicalPlan::SubqueryAlias(subquery_alias))) + Ok(Transformed::yes(LogicalPlan::subquery_alias( + subquery_alias, + ))) } - LogicalPlan::Extension(extension_plan) + LogicalPlan::Extension(extension_plan, _) if extension_plan.node.supports_limit_pushdown() => { let new_children = extension_plan @@ -162,7 +164,7 @@ impl OptimizerRule for PushDownLimit { .inputs() .into_iter() .map(|child| { - LogicalPlan::Limit(Limit { + LogicalPlan::limit(Limit { skip: None, fetch: Some(Box::new(lit((fetch + skip) as i64))), input: Arc::new(child.clone()), @@ -171,7 +173,7 @@ impl OptimizerRule for PushDownLimit { .collect::>(); // Create a new extension node with updated inputs - let child_plan = LogicalPlan::Extension(extension_plan); + let child_plan = LogicalPlan::extension(extension_plan); let new_extension = child_plan.with_new_exprs(child_plan.expressions(), new_children)?; @@ -203,7 +205,7 @@ impl OptimizerRule for PushDownLimit { /// input /// ``` fn make_limit(skip: usize, fetch: usize, input: Arc) -> LogicalPlan { - LogicalPlan::Limit(Limit { + LogicalPlan::limit(Limit { skip: Some(Box::new(lit(skip as i64))), fetch: Some(Box::new(lit(fetch as i64))), input, @@ -400,7 +402,7 @@ mod test { #[test] fn limit_pushdown_basic() -> Result<()> { let table_scan = test_table_scan()?; - let noop_plan = LogicalPlan::Extension(Extension { + let noop_plan = LogicalPlan::extension(Extension { node: Arc::new(NoopPlan { input: vec![table_scan.clone()], schema: Arc::clone(table_scan.schema()), @@ -422,7 +424,7 @@ mod test { #[test] fn limit_pushdown_with_skip() -> Result<()> { let table_scan = test_table_scan()?; - let noop_plan = LogicalPlan::Extension(Extension { + let noop_plan = LogicalPlan::extension(Extension { node: Arc::new(NoopPlan { input: vec![table_scan.clone()], schema: Arc::clone(table_scan.schema()), @@ -444,7 +446,7 @@ mod test { #[test] fn limit_pushdown_multiple_limits() -> Result<()> { let table_scan = test_table_scan()?; - let noop_plan = LogicalPlan::Extension(Extension { + let noop_plan = LogicalPlan::extension(Extension { node: Arc::new(NoopPlan { input: vec![table_scan.clone()], schema: Arc::clone(table_scan.schema()), @@ -467,7 +469,7 @@ mod test { #[test] fn limit_pushdown_multiple_inputs() -> Result<()> { let table_scan = test_table_scan()?; - let noop_plan = LogicalPlan::Extension(Extension { + let noop_plan = LogicalPlan::extension(Extension { node: Arc::new(NoopPlan { input: vec![table_scan.clone(), table_scan.clone()], schema: Arc::clone(table_scan.schema()), @@ -491,7 +493,7 @@ mod test { #[test] fn limit_pushdown_disallowed_noop_plan() -> Result<()> { let table_scan = test_table_scan()?; - let no_limit_noop_plan = LogicalPlan::Extension(Extension { + let no_limit_noop_plan = LogicalPlan::extension(Extension { node: Arc::new(NoLimitNoopPlan { input: vec![table_scan.clone()], schema: Arc::clone(table_scan.schema()), diff --git a/datafusion/optimizer/src/replace_distinct_aggregate.rs b/datafusion/optimizer/src/replace_distinct_aggregate.rs index f3e1673e72111..f2500a09104bb 100644 --- a/datafusion/optimizer/src/replace_distinct_aggregate.rs +++ b/datafusion/optimizer/src/replace_distinct_aggregate.rs @@ -77,7 +77,7 @@ impl OptimizerRule for ReplaceDistinctWithAggregate { config: &dyn OptimizerConfig, ) -> Result> { match plan { - LogicalPlan::Distinct(Distinct::All(input)) => { + LogicalPlan::Distinct(Distinct::All(input), _) => { let group_expr = expand_wildcard(input.schema(), &input, None)?; let field_count = input.schema().fields().len(); @@ -95,20 +95,23 @@ impl OptimizerRule for ReplaceDistinctWithAggregate { } // Replace with aggregation: - let aggr_plan = LogicalPlan::Aggregate(Aggregate::try_new( + let aggr_plan = LogicalPlan::aggregate(Aggregate::try_new( input, group_expr, vec![], )?); Ok(Transformed::yes(aggr_plan)) } - LogicalPlan::Distinct(Distinct::On(DistinctOn { - select_expr, - on_expr, - sort_expr, - input, - schema, - })) => { + LogicalPlan::Distinct( + Distinct::On(DistinctOn { + select_expr, + on_expr, + sort_expr, + input, + schema, + }), + _, + ) => { let expr_cnt = on_expr.len(); // Construct the aggregation expression to be used to fetch the selected expressions. @@ -131,7 +134,7 @@ impl OptimizerRule for ReplaceDistinctWithAggregate { let group_expr = normalize_cols(on_expr, input.as_ref())?; // Build the aggregation plan - let plan = LogicalPlan::Aggregate(Aggregate::try_new( + let plan = LogicalPlan::aggregate(Aggregate::try_new( input, group_expr, aggr_expr, )?); // TODO use LogicalPlanBuilder directly rather than recreating the Aggregate diff --git a/datafusion/optimizer/src/scalar_subquery_to_join.rs b/datafusion/optimizer/src/scalar_subquery_to_join.rs index 2e2c8fb1d6f8c..f9b247fc9a982 100644 --- a/datafusion/optimizer/src/scalar_subquery_to_join.rs +++ b/datafusion/optimizer/src/scalar_subquery_to_join.rs @@ -18,6 +18,7 @@ //! [`ScalarSubqueryToJoin`] rewriting scalar subquery filters to `JOIN`s use std::collections::{BTreeSet, HashMap}; +use std::ops::Not; use std::sync::Arc; use crate::decorrelate::{PullUpCorrelatedExpr, UN_MATCHED_ROW_INDICATOR}; @@ -79,11 +80,11 @@ impl OptimizerRule for ScalarSubqueryToJoin { config: &dyn OptimizerConfig, ) -> Result> { match plan { - LogicalPlan::Filter(filter) => { + LogicalPlan::Filter(filter, _) => { // Optimization: skip the rest of the rule and its copies if // there are no scalar subqueries if !contains_scalar_subquery(&filter.predicate) { - return Ok(Transformed::no(LogicalPlan::Filter(filter))); + return Ok(Transformed::no(LogicalPlan::filter(filter))); } let (subqueries, mut rewrite_expr) = self.extract_subquery_exprs( @@ -119,7 +120,7 @@ impl OptimizerRule for ScalarSubqueryToJoin { cur_input = optimized_subquery; } else { // if we can't handle all of the subqueries then bail for now - return Ok(Transformed::no(LogicalPlan::Filter(filter))); + return Ok(Transformed::no(LogicalPlan::filter(filter))); } } let new_plan = LogicalPlanBuilder::from(cur_input) @@ -127,11 +128,11 @@ impl OptimizerRule for ScalarSubqueryToJoin { .build()?; Ok(Transformed::yes(new_plan)) } - LogicalPlan::Projection(projection) => { + LogicalPlan::Projection(projection, _) => { // Optimization: skip the rest of the rule and its copies if // there are no scalar subqueries if !projection.expr.iter().any(contains_scalar_subquery) { - return Ok(Transformed::no(LogicalPlan::Projection(projection))); + return Ok(Transformed::no(LogicalPlan::projection(projection))); } let mut all_subqueryies = vec![]; @@ -182,7 +183,7 @@ impl OptimizerRule for ScalarSubqueryToJoin { } } else { // if we can't handle all of the subqueries then bail for now - return Ok(Transformed::no(LogicalPlan::Projection(projection))); + return Ok(Transformed::no(LogicalPlan::projection(projection))); } } @@ -219,7 +220,7 @@ impl OptimizerRule for ScalarSubqueryToJoin { /// Returns true if the expression has a scalar subquery somewhere in it /// false otherwise fn contains_scalar_subquery(expr: &Expr) -> bool { - expr.exists(|expr| Ok(matches!(expr, Expr::ScalarSubquery(_)))) + expr.exists(|expr| Ok(matches!(expr, Expr::ScalarSubquery(_, _)))) .expect("Inner is always Ok") } @@ -233,7 +234,7 @@ impl TreeNodeRewriter for ExtractScalarSubQuery<'_> { fn f_down(&mut self, expr: Expr) -> Result> { match expr { - Expr::ScalarSubquery(subquery) => { + Expr::ScalarSubquery(subquery, _) => { let subqry_alias = self.alias_gen.next("__scalar_sq"); self.sub_query_info .push((subquery.clone(), subqry_alias.clone())); @@ -242,7 +243,7 @@ impl TreeNodeRewriter for ExtractScalarSubQuery<'_> { .head_output_expr()? .map_or(plan_err!("single expression required."), Ok)?; Ok(Transformed::new( - Expr::Column(create_col_from_scalar_expr( + Expr::column(create_col_from_scalar_expr( &scalar_expr, subqry_alias, )?), @@ -324,10 +325,13 @@ fn build_join( // join our sub query into the main plan let new_plan = if join_filter_opt.is_none() { match filter_input { - LogicalPlan::EmptyRelation(EmptyRelation { - produce_one_row: true, - schema: _, - }) => sub_query_alias, + LogicalPlan::EmptyRelation( + EmptyRelation { + produce_one_row: true, + schema: _, + }, + _, + ) => sub_query_alias, _ => { // if not correlated, group down to 1 row and left join on that (preserving row count) LogicalPlanBuilder::from(filter_input.clone()) @@ -345,34 +349,40 @@ fn build_join( if let Some(expr_map) = collected_count_expr_map { for (name, result) in expr_map { let computer_expr = if let Some(filter) = &pull_up.pull_up_having_expr { - Expr::Case(expr::Case { + Expr::case(expr::Case { expr: None, when_then_expr: vec![ ( - Box::new(Expr::IsNull(Box::new(Expr::Column( - Column::new_unqualified(UN_MATCHED_ROW_INDICATOR), - )))), + Box::new( + Expr::column(Column::new_unqualified( + UN_MATCHED_ROW_INDICATOR, + )) + .is_null(), + ), Box::new(result), ), ( - Box::new(Expr::Not(Box::new(filter.clone()))), - Box::new(Expr::Literal(ScalarValue::Null)), + Box::new(filter.clone().not()), + Box::new(Expr::literal(ScalarValue::Null)), ), ], - else_expr: Some(Box::new(Expr::Column(Column::new_unqualified( + else_expr: Some(Box::new(Expr::column(Column::new_unqualified( name.clone(), )))), }) } else { - Expr::Case(expr::Case { + Expr::case(expr::Case { expr: None, when_then_expr: vec![( - Box::new(Expr::IsNull(Box::new(Expr::Column( - Column::new_unqualified(UN_MATCHED_ROW_INDICATOR), - )))), + Box::new( + Expr::column(Column::new_unqualified( + UN_MATCHED_ROW_INDICATOR, + )) + .is_null(), + ), Box::new(result), )], - else_expr: Some(Box::new(Expr::Column(Column::new_unqualified( + else_expr: Some(Box::new(Expr::column(Column::new_unqualified( name.clone(), )))), }) @@ -1038,7 +1048,7 @@ mod tests { .build()?, ); - let between_expr = Expr::Between(Between { + let between_expr = Expr::_between(Between { expr: Box::new(col("customer.c_custkey")), negated: false, low: Box::new(scalar_subquery(sq1)), @@ -1087,7 +1097,7 @@ mod tests { .build()?, ); - let between_expr = Expr::Between(Between { + let between_expr = Expr::_between(Between { expr: Box::new(col("customer.c_custkey")), negated: false, low: Box::new(scalar_subquery(sq1)), diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 6564e722eaf89..e80c1365a8f01 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -43,17 +43,18 @@ use datafusion_expr::{ utils::{iter_conjunction, iter_conjunction_owned}, }; use datafusion_physical_expr::{create_physical_expr, execution_props::ExecutionProps}; +use enumset::enum_set; use indexmap::IndexSet; +use super::inlist_simplifier::ShortenInListSimplifier; +use super::utils::*; use crate::analyzer::type_coercion::TypeCoercionRewriter; use crate::simplify_expressions::guarantees::GuaranteeRewriter; use crate::simplify_expressions::regex::simplify_regex_expr; use crate::simplify_expressions::SimplifyInfo; +use datafusion_expr::logical_plan::tree_node::LogicalPlanPattern; use regex::Regex; -use super::inlist_simplifier::ShortenInListSimplifier; -use super::utils::*; - /// This structure handles API for expression simplification /// /// Provides simplification information based on DFSchema and @@ -410,41 +411,63 @@ impl ExprSimplifier { /// /// ` ` is rewritten so that the name of `col1` sorts higher /// than `col2` (`a > b` would be canonicalized to `b < a`) -struct Canonicalizer {} +struct Canonicalizer { + skip: bool, +} impl Canonicalizer { fn new() -> Self { - Self {} + Self { skip: false } } } impl TreeNodeRewriter for Canonicalizer { type Node = Expr; + fn f_down(&mut self, node: Self::Node) -> Result> { + if !(node + .stats() + .contains_pattern(LogicalPlanPattern::ExprBinaryExpr) + && node.stats().contains_any_patterns(enum_set!( + LogicalPlanPattern::ExprColumn | LogicalPlanPattern::ExprLiteral + ))) + { + self.skip = true; + return Ok(Transformed::jump(node)); + } + + Ok(Transformed::no(node)) + } + fn f_up(&mut self, expr: Expr) -> Result> { - let Expr::BinaryExpr(BinaryExpr { left, op, right }) = expr else { + if self.skip { + self.skip = false; + return Ok(Transformed::no(expr)); + } + + let Expr::BinaryExpr(BinaryExpr { left, op, right }, _) = expr else { return Ok(Transformed::no(expr)); }; match (left.as_ref(), right.as_ref(), op.swap()) { // - (Expr::Column(left_col), Expr::Column(right_col), Some(swapped_op)) + (Expr::Column(left_col, _), Expr::Column(right_col, _), Some(swapped_op)) if right_col > left_col => { - Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr { + Ok(Transformed::yes(Expr::binary_expr(BinaryExpr { left: right, op: swapped_op, right: left, }))) } // - (Expr::Literal(_a), Expr::Column(_b), Some(swapped_op)) => { - Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr { + (Expr::Literal(_a, _), Expr::Column(_b, _), Some(swapped_op)) => { + Ok(Transformed::yes(Expr::binary_expr(BinaryExpr { left: right, op: swapped_op, right: left, }))) } - _ => Ok(Transformed::no(Expr::BinaryExpr(BinaryExpr { + _ => Ok(Transformed::no(Expr::binary_expr(BinaryExpr { left, op, right, @@ -529,10 +552,10 @@ impl<'a> TreeNodeRewriter for ConstEvaluator<'a> { let result = self.evaluate_to_scalar(expr); match result { ConstSimplifyResult::Simplified(s) => { - Ok(Transformed::yes(Expr::Literal(s))) + Ok(Transformed::yes(Expr::literal(s))) } ConstSimplifyResult::NotSimplified(s) => { - Ok(Transformed::no(Expr::Literal(s))) + Ok(Transformed::no(Expr::literal(s))) } ConstSimplifyResult::SimplifyRuntimeError(_, expr) => { Ok(Transformed::yes(expr)) @@ -589,36 +612,36 @@ impl<'a> ConstEvaluator<'a> { // Has no runtime cost, but needed during planning Expr::Alias(..) | Expr::AggregateFunction { .. } - | Expr::ScalarVariable(_, _) - | Expr::Column(_) - | Expr::OuterReferenceColumn(_, _) + | Expr::ScalarVariable(_, _, _) + | Expr::Column(_, _) + | Expr::OuterReferenceColumn(_, _, _) | Expr::Exists { .. } - | Expr::InSubquery(_) - | Expr::ScalarSubquery(_) + | Expr::InSubquery(_, _) + | Expr::ScalarSubquery(_, _) | Expr::WindowFunction { .. } - | Expr::GroupingSet(_) + | Expr::GroupingSet(_, _) | Expr::Wildcard { .. } - | Expr::Placeholder(_) => false, - Expr::ScalarFunction(ScalarFunction { func, .. }) => { + | Expr::Placeholder(_, _) => false, + Expr::ScalarFunction(ScalarFunction { func, .. }, _) => { Self::volatility_ok(func.signature().volatility) } - Expr::Literal(_) - | Expr::Unnest(_) + Expr::Literal(_, _) + | Expr::Unnest(_, _) | Expr::BinaryExpr { .. } - | Expr::Not(_) - | Expr::IsNotNull(_) - | Expr::IsNull(_) - | Expr::IsTrue(_) - | Expr::IsFalse(_) - | Expr::IsUnknown(_) - | Expr::IsNotTrue(_) - | Expr::IsNotFalse(_) - | Expr::IsNotUnknown(_) - | Expr::Negative(_) + | Expr::Not(_, _) + | Expr::IsNotNull(_, _) + | Expr::IsNull(_, _) + | Expr::IsTrue(_, _) + | Expr::IsFalse(_, _) + | Expr::IsUnknown(_, _) + | Expr::IsNotTrue(_, _) + | Expr::IsNotFalse(_, _) + | Expr::IsNotUnknown(_, _) + | Expr::Negative(_, _) | Expr::Between { .. } | Expr::Like { .. } | Expr::SimilarTo { .. } - | Expr::Case(_) + | Expr::Case(_, _) | Expr::Cast { .. } | Expr::TryCast { .. } | Expr::InList { .. } => true, @@ -627,7 +650,7 @@ impl<'a> ConstEvaluator<'a> { /// Internal helper to evaluates an Expr pub(crate) fn evaluate_to_scalar(&mut self, expr: Expr) -> ConstSimplifyResult { - if let Expr::Literal(s) = expr { + if let Expr::Literal(s, _) = expr { return ConstSimplifyResult::NotSimplified(s); } @@ -713,6 +736,27 @@ impl<'a, S> Simplifier<'a, S> { impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { type Node = Expr; + fn f_down(&mut self, node: Self::Node) -> Result> { + if !node.stats().contains_any_patterns(enum_set!( + LogicalPlanPattern::ExprBinaryExpr + | LogicalPlanPattern::ExprNot + | LogicalPlanPattern::ExprNegative + | LogicalPlanPattern::ExprCase + | LogicalPlanPattern::ExprScalarFunction + | LogicalPlanPattern::ExprAggregateFunction + | LogicalPlanPattern::ExprWindowFunction + | LogicalPlanPattern::ExprBetween + | LogicalPlanPattern::ExprIsNotNull + | LogicalPlanPattern::ExprIsNull + | LogicalPlanPattern::ExprInList + | LogicalPlanPattern::ExprLike + )) { + return Ok(Transformed::jump(node)); + } + + Ok(Transformed::no(node)) + } + /// rewrite the expression simplifying any constant expressions fn f_up(&mut self, expr: Expr) -> Result> { use datafusion_expr::Operator::{ @@ -730,28 +774,34 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { // true = A --> A // false = A --> !A // null = A --> null - Expr::BinaryExpr(BinaryExpr { - left, - op: Eq, - right, - }) if is_bool_lit(&left) && info.is_boolean_type(&right)? => { + Expr::BinaryExpr( + BinaryExpr { + left, + op: Eq, + right, + }, + _, + ) if is_bool_lit(&left) && info.is_boolean_type(&right)? => { Transformed::yes(match as_bool_lit(&left)? { Some(true) => *right, - Some(false) => Expr::Not(right), + Some(false) => Expr::_not(right), None => lit_bool_null(), }) } // A = true --> A // A = false --> !A // A = null --> null - Expr::BinaryExpr(BinaryExpr { - left, - op: Eq, - right, - }) if is_bool_lit(&right) && info.is_boolean_type(&left)? => { + Expr::BinaryExpr( + BinaryExpr { + left, + op: Eq, + right, + }, + _, + ) if is_bool_lit(&right) && info.is_boolean_type(&left)? => { Transformed::yes(match as_bool_lit(&right)? { Some(true) => *left, - Some(false) => Expr::Not(left), + Some(false) => Expr::_not(left), None => lit_bool_null(), }) } @@ -761,13 +811,16 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { // true != A --> !A // false != A --> A // null != A --> null - Expr::BinaryExpr(BinaryExpr { - left, - op: NotEq, - right, - }) if is_bool_lit(&left) && info.is_boolean_type(&right)? => { + Expr::BinaryExpr( + BinaryExpr { + left, + op: NotEq, + right, + }, + _, + ) if is_bool_lit(&left) && info.is_boolean_type(&right)? => { Transformed::yes(match as_bool_lit(&left)? { - Some(true) => Expr::Not(right), + Some(true) => Expr::_not(right), Some(false) => *right, None => lit_bool_null(), }) @@ -775,13 +828,16 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { // A != true --> !A // A != false --> A // A != null --> null, - Expr::BinaryExpr(BinaryExpr { - left, - op: NotEq, - right, - }) if is_bool_lit(&right) && info.is_boolean_type(&left)? => { + Expr::BinaryExpr( + BinaryExpr { + left, + op: NotEq, + right, + }, + _, + ) if is_bool_lit(&right) && info.is_boolean_type(&left)? => { Transformed::yes(match as_bool_lit(&right)? { - Some(true) => Expr::Not(left), + Some(true) => Expr::_not(left), Some(false) => *left, None => lit_bool_null(), }) @@ -792,76 +848,109 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { // // true OR A --> true (even if A is null) - Expr::BinaryExpr(BinaryExpr { - left, - op: Or, - right: _, - }) if is_true(&left) => Transformed::yes(*left), + Expr::BinaryExpr( + BinaryExpr { + left, + op: Or, + right: _, + }, + _, + ) if is_true(&left) => Transformed::yes(*left), // false OR A --> A - Expr::BinaryExpr(BinaryExpr { - left, - op: Or, - right, - }) if is_false(&left) => Transformed::yes(*right), + Expr::BinaryExpr( + BinaryExpr { + left, + op: Or, + right, + }, + _, + ) if is_false(&left) => Transformed::yes(*right), // A OR true --> true (even if A is null) - Expr::BinaryExpr(BinaryExpr { - left: _, - op: Or, - right, - }) if is_true(&right) => Transformed::yes(*right), + Expr::BinaryExpr( + BinaryExpr { + left: _, + op: Or, + right, + }, + _, + ) if is_true(&right) => Transformed::yes(*right), // A OR false --> A - Expr::BinaryExpr(BinaryExpr { - left, - op: Or, - right, - }) if is_false(&right) => Transformed::yes(*left), + Expr::BinaryExpr( + BinaryExpr { + left, + op: Or, + right, + }, + _, + ) if is_false(&right) => Transformed::yes(*left), // A OR !A ---> true (if A not nullable) - Expr::BinaryExpr(BinaryExpr { - left, - op: Or, - right, - }) if is_not_of(&right, &left) && !info.nullable(&left)? => { + Expr::BinaryExpr( + BinaryExpr { + left, + op: Or, + right, + }, + _, + ) if is_not_of(&right, &left) && !info.nullable(&left)? => { Transformed::yes(lit(true)) } // !A OR A ---> true (if A not nullable) - Expr::BinaryExpr(BinaryExpr { - left, - op: Or, - right, - }) if is_not_of(&left, &right) && !info.nullable(&right)? => { + Expr::BinaryExpr( + BinaryExpr { + left, + op: Or, + right, + }, + _, + ) if is_not_of(&left, &right) && !info.nullable(&right)? => { Transformed::yes(lit(true)) } // (..A..) OR A --> (..A..) - Expr::BinaryExpr(BinaryExpr { - left, - op: Or, - right, - }) if expr_contains(&left, &right, Or) => Transformed::yes(*left), + Expr::BinaryExpr( + BinaryExpr { + left, + op: Or, + right, + }, + _, + ) if expr_contains(&left, &right, Or) => Transformed::yes(*left), // A OR (..A..) --> (..A..) - Expr::BinaryExpr(BinaryExpr { - left, - op: Or, - right, - }) if expr_contains(&right, &left, Or) => Transformed::yes(*right), + Expr::BinaryExpr( + BinaryExpr { + left, + op: Or, + right, + }, + _, + ) if expr_contains(&right, &left, Or) => Transformed::yes(*right), // A OR (A AND B) --> A - Expr::BinaryExpr(BinaryExpr { - left, - op: Or, - right, - }) if is_op_with(And, &right, &left) => Transformed::yes(*left), + Expr::BinaryExpr( + BinaryExpr { + left, + op: Or, + right, + }, + _, + ) if is_op_with(And, &right, &left) => Transformed::yes(*left), // (A AND B) OR A --> A - Expr::BinaryExpr(BinaryExpr { - left, - op: Or, - right, - }) if is_op_with(And, &left, &right) => Transformed::yes(*right), + Expr::BinaryExpr( + BinaryExpr { + left, + op: Or, + right, + }, + _, + ) if is_op_with(And, &left, &right) => Transformed::yes(*right), // Eliminate common factors in conjunctions e.g // (A AND B) OR (A AND C) -> A AND (B OR C) - Expr::BinaryExpr(BinaryExpr { - left, - op: Or, - right, - }) if has_common_conjunction(&left, &right) => { + Expr::BinaryExpr( + BinaryExpr { + left, + op: Or, + right, + }, + _, + ) if has_common_conjunction(&left, &right) => { let lhs: IndexSet = iter_conjunction_owned(*left).collect(); let (common, rhs): (Vec<_>, Vec<_>) = iter_conjunction_owned(*right) .partition(|e| lhs.contains(e) && !e.is_volatile()); @@ -882,116 +971,164 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { // // true AND A --> A - Expr::BinaryExpr(BinaryExpr { - left, - op: And, - right, - }) if is_true(&left) => Transformed::yes(*right), + Expr::BinaryExpr( + BinaryExpr { + left, + op: And, + right, + }, + _, + ) if is_true(&left) => Transformed::yes(*right), // false AND A --> false (even if A is null) - Expr::BinaryExpr(BinaryExpr { - left, - op: And, - right: _, - }) if is_false(&left) => Transformed::yes(*left), + Expr::BinaryExpr( + BinaryExpr { + left, + op: And, + right: _, + }, + _, + ) if is_false(&left) => Transformed::yes(*left), // A AND true --> A - Expr::BinaryExpr(BinaryExpr { - left, - op: And, - right, - }) if is_true(&right) => Transformed::yes(*left), + Expr::BinaryExpr( + BinaryExpr { + left, + op: And, + right, + }, + _, + ) if is_true(&right) => Transformed::yes(*left), // A AND false --> false (even if A is null) - Expr::BinaryExpr(BinaryExpr { - left: _, - op: And, - right, - }) if is_false(&right) => Transformed::yes(*right), + Expr::BinaryExpr( + BinaryExpr { + left: _, + op: And, + right, + }, + _, + ) if is_false(&right) => Transformed::yes(*right), // A AND !A ---> false (if A not nullable) - Expr::BinaryExpr(BinaryExpr { - left, - op: And, - right, - }) if is_not_of(&right, &left) && !info.nullable(&left)? => { + Expr::BinaryExpr( + BinaryExpr { + left, + op: And, + right, + }, + _, + ) if is_not_of(&right, &left) && !info.nullable(&left)? => { Transformed::yes(lit(false)) } // !A AND A ---> false (if A not nullable) - Expr::BinaryExpr(BinaryExpr { - left, - op: And, - right, - }) if is_not_of(&left, &right) && !info.nullable(&right)? => { + Expr::BinaryExpr( + BinaryExpr { + left, + op: And, + right, + }, + _, + ) if is_not_of(&left, &right) && !info.nullable(&right)? => { Transformed::yes(lit(false)) } // (..A..) AND A --> (..A..) - Expr::BinaryExpr(BinaryExpr { - left, - op: And, - right, - }) if expr_contains(&left, &right, And) => Transformed::yes(*left), + Expr::BinaryExpr( + BinaryExpr { + left, + op: And, + right, + }, + _, + ) if expr_contains(&left, &right, And) => Transformed::yes(*left), // A AND (..A..) --> (..A..) - Expr::BinaryExpr(BinaryExpr { - left, - op: And, - right, - }) if expr_contains(&right, &left, And) => Transformed::yes(*right), + Expr::BinaryExpr( + BinaryExpr { + left, + op: And, + right, + }, + _, + ) if expr_contains(&right, &left, And) => Transformed::yes(*right), // A AND (A OR B) --> A - Expr::BinaryExpr(BinaryExpr { - left, - op: And, - right, - }) if is_op_with(Or, &right, &left) => Transformed::yes(*left), + Expr::BinaryExpr( + BinaryExpr { + left, + op: And, + right, + }, + _, + ) if is_op_with(Or, &right, &left) => Transformed::yes(*left), // (A OR B) AND A --> A - Expr::BinaryExpr(BinaryExpr { - left, - op: And, - right, - }) if is_op_with(Or, &left, &right) => Transformed::yes(*right), + Expr::BinaryExpr( + BinaryExpr { + left, + op: And, + right, + }, + _, + ) if is_op_with(Or, &left, &right) => Transformed::yes(*right), // // Rules for Multiply // // A * 1 --> A - Expr::BinaryExpr(BinaryExpr { - left, - op: Multiply, - right, - }) if is_one(&right) => Transformed::yes(*left), + Expr::BinaryExpr( + BinaryExpr { + left, + op: Multiply, + right, + }, + _, + ) if is_one(&right) => Transformed::yes(*left), // 1 * A --> A - Expr::BinaryExpr(BinaryExpr { - left, - op: Multiply, - right, - }) if is_one(&left) => Transformed::yes(*right), + Expr::BinaryExpr( + BinaryExpr { + left, + op: Multiply, + right, + }, + _, + ) if is_one(&left) => Transformed::yes(*right), // A * null --> null - Expr::BinaryExpr(BinaryExpr { - left: _, - op: Multiply, - right, - }) if is_null(&right) => Transformed::yes(*right), + Expr::BinaryExpr( + BinaryExpr { + left: _, + op: Multiply, + right, + }, + _, + ) if is_null(&right) => Transformed::yes(*right), // null * A --> null - Expr::BinaryExpr(BinaryExpr { - left, - op: Multiply, - right: _, - }) if is_null(&left) => Transformed::yes(*left), + Expr::BinaryExpr( + BinaryExpr { + left, + op: Multiply, + right: _, + }, + _, + ) if is_null(&left) => Transformed::yes(*left), // A * 0 --> 0 (if A is not null and not floating, since NAN * 0 -> NAN) - Expr::BinaryExpr(BinaryExpr { - left, - op: Multiply, - right, - }) if !info.nullable(&left)? + Expr::BinaryExpr( + BinaryExpr { + left, + op: Multiply, + right, + }, + _, + ) if !info.nullable(&left)? && !info.get_data_type(&left)?.is_floating() && is_zero(&right) => { Transformed::yes(*right) } // 0 * A --> 0 (if A is not null and not floating, since 0 * NAN -> NAN) - Expr::BinaryExpr(BinaryExpr { - left, - op: Multiply, - right, - }) if !info.nullable(&right)? + Expr::BinaryExpr( + BinaryExpr { + left, + op: Multiply, + right, + }, + _, + ) if !info.nullable(&right)? && !info.get_data_type(&right)?.is_floating() && is_zero(&left) => { @@ -1003,50 +1140,68 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { // // A / 1 --> A - Expr::BinaryExpr(BinaryExpr { - left, - op: Divide, - right, - }) if is_one(&right) => Transformed::yes(*left), + Expr::BinaryExpr( + BinaryExpr { + left, + op: Divide, + right, + }, + _, + ) if is_one(&right) => Transformed::yes(*left), // null / A --> null - Expr::BinaryExpr(BinaryExpr { - left, - op: Divide, - right: _, - }) if is_null(&left) => Transformed::yes(*left), + Expr::BinaryExpr( + BinaryExpr { + left, + op: Divide, + right: _, + }, + _, + ) if is_null(&left) => Transformed::yes(*left), // A / null --> null - Expr::BinaryExpr(BinaryExpr { - left: _, - op: Divide, - right, - }) if is_null(&right) => Transformed::yes(*right), + Expr::BinaryExpr( + BinaryExpr { + left: _, + op: Divide, + right, + }, + _, + ) if is_null(&right) => Transformed::yes(*right), // // Rules for Modulo // // A % null --> null - Expr::BinaryExpr(BinaryExpr { - left: _, - op: Modulo, - right, - }) if is_null(&right) => Transformed::yes(*right), + Expr::BinaryExpr( + BinaryExpr { + left: _, + op: Modulo, + right, + }, + _, + ) if is_null(&right) => Transformed::yes(*right), // null % A --> null - Expr::BinaryExpr(BinaryExpr { - left, - op: Modulo, - right: _, - }) if is_null(&left) => Transformed::yes(*left), + Expr::BinaryExpr( + BinaryExpr { + left, + op: Modulo, + right: _, + }, + _, + ) if is_null(&left) => Transformed::yes(*left), // A % 1 --> 0 (if A is not nullable and not floating, since NAN % 1 --> NAN) - Expr::BinaryExpr(BinaryExpr { - left, - op: Modulo, - right, - }) if !info.nullable(&left)? + Expr::BinaryExpr( + BinaryExpr { + left, + op: Modulo, + right, + }, + _, + ) if !info.nullable(&left)? && !info.get_data_type(&left)?.is_floating() && is_one(&right) => { - Transformed::yes(Expr::Literal(ScalarValue::new_zero( + Transformed::yes(Expr::literal(ScalarValue::new_zero( &info.get_data_type(&left)?, )?)) } @@ -1056,84 +1211,114 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { // // A & null -> null - Expr::BinaryExpr(BinaryExpr { - left: _, - op: BitwiseAnd, - right, - }) if is_null(&right) => Transformed::yes(*right), + Expr::BinaryExpr( + BinaryExpr { + left: _, + op: BitwiseAnd, + right, + }, + _, + ) if is_null(&right) => Transformed::yes(*right), // null & A -> null - Expr::BinaryExpr(BinaryExpr { - left, - op: BitwiseAnd, - right: _, - }) if is_null(&left) => Transformed::yes(*left), + Expr::BinaryExpr( + BinaryExpr { + left, + op: BitwiseAnd, + right: _, + }, + _, + ) if is_null(&left) => Transformed::yes(*left), // A & 0 -> 0 (if A not nullable) - Expr::BinaryExpr(BinaryExpr { - left, - op: BitwiseAnd, - right, - }) if !info.nullable(&left)? && is_zero(&right) => Transformed::yes(*right), + Expr::BinaryExpr( + BinaryExpr { + left, + op: BitwiseAnd, + right, + }, + _, + ) if !info.nullable(&left)? && is_zero(&right) => Transformed::yes(*right), // 0 & A -> 0 (if A not nullable) - Expr::BinaryExpr(BinaryExpr { - left, - op: BitwiseAnd, - right, - }) if !info.nullable(&right)? && is_zero(&left) => Transformed::yes(*left), + Expr::BinaryExpr( + BinaryExpr { + left, + op: BitwiseAnd, + right, + }, + _, + ) if !info.nullable(&right)? && is_zero(&left) => Transformed::yes(*left), // !A & A -> 0 (if A not nullable) - Expr::BinaryExpr(BinaryExpr { - left, - op: BitwiseAnd, - right, - }) if is_negative_of(&left, &right) && !info.nullable(&right)? => { - Transformed::yes(Expr::Literal(ScalarValue::new_zero( + Expr::BinaryExpr( + BinaryExpr { + left, + op: BitwiseAnd, + right, + }, + _, + ) if is_negative_of(&left, &right) && !info.nullable(&right)? => { + Transformed::yes(Expr::literal(ScalarValue::new_zero( &info.get_data_type(&left)?, )?)) } // A & !A -> 0 (if A not nullable) - Expr::BinaryExpr(BinaryExpr { - left, - op: BitwiseAnd, - right, - }) if is_negative_of(&right, &left) && !info.nullable(&left)? => { - Transformed::yes(Expr::Literal(ScalarValue::new_zero( + Expr::BinaryExpr( + BinaryExpr { + left, + op: BitwiseAnd, + right, + }, + _, + ) if is_negative_of(&right, &left) && !info.nullable(&left)? => { + Transformed::yes(Expr::literal(ScalarValue::new_zero( &info.get_data_type(&left)?, )?)) } // (..A..) & A --> (..A..) - Expr::BinaryExpr(BinaryExpr { - left, - op: BitwiseAnd, - right, - }) if expr_contains(&left, &right, BitwiseAnd) => Transformed::yes(*left), + Expr::BinaryExpr( + BinaryExpr { + left, + op: BitwiseAnd, + right, + }, + _, + ) if expr_contains(&left, &right, BitwiseAnd) => Transformed::yes(*left), // A & (..A..) --> (..A..) - Expr::BinaryExpr(BinaryExpr { - left, - op: BitwiseAnd, - right, - }) if expr_contains(&right, &left, BitwiseAnd) => Transformed::yes(*right), + Expr::BinaryExpr( + BinaryExpr { + left, + op: BitwiseAnd, + right, + }, + _, + ) if expr_contains(&right, &left, BitwiseAnd) => Transformed::yes(*right), // A & (A | B) --> A (if B not null) - Expr::BinaryExpr(BinaryExpr { - left, - op: BitwiseAnd, - right, - }) if !info.nullable(&right)? && is_op_with(BitwiseOr, &right, &left) => { + Expr::BinaryExpr( + BinaryExpr { + left, + op: BitwiseAnd, + right, + }, + _, + ) if !info.nullable(&right)? && is_op_with(BitwiseOr, &right, &left) => { Transformed::yes(*left) } // (A | B) & A --> A (if B not null) - Expr::BinaryExpr(BinaryExpr { - left, - op: BitwiseAnd, - right, - }) if !info.nullable(&left)? && is_op_with(BitwiseOr, &left, &right) => { + Expr::BinaryExpr( + BinaryExpr { + left, + op: BitwiseAnd, + right, + }, + _, + ) if !info.nullable(&left)? && is_op_with(BitwiseOr, &left, &right) => { Transformed::yes(*right) } @@ -1142,84 +1327,114 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { // // A | null -> null - Expr::BinaryExpr(BinaryExpr { - left: _, - op: BitwiseOr, - right, - }) if is_null(&right) => Transformed::yes(*right), + Expr::BinaryExpr( + BinaryExpr { + left: _, + op: BitwiseOr, + right, + }, + _, + ) if is_null(&right) => Transformed::yes(*right), // null | A -> null - Expr::BinaryExpr(BinaryExpr { - left, - op: BitwiseOr, - right: _, - }) if is_null(&left) => Transformed::yes(*left), + Expr::BinaryExpr( + BinaryExpr { + left, + op: BitwiseOr, + right: _, + }, + _, + ) if is_null(&left) => Transformed::yes(*left), // A | 0 -> A (even if A is null) - Expr::BinaryExpr(BinaryExpr { - left, - op: BitwiseOr, - right, - }) if is_zero(&right) => Transformed::yes(*left), + Expr::BinaryExpr( + BinaryExpr { + left, + op: BitwiseOr, + right, + }, + _, + ) if is_zero(&right) => Transformed::yes(*left), // 0 | A -> A (even if A is null) - Expr::BinaryExpr(BinaryExpr { - left, - op: BitwiseOr, - right, - }) if is_zero(&left) => Transformed::yes(*right), + Expr::BinaryExpr( + BinaryExpr { + left, + op: BitwiseOr, + right, + }, + _, + ) if is_zero(&left) => Transformed::yes(*right), // !A | A -> -1 (if A not nullable) - Expr::BinaryExpr(BinaryExpr { - left, - op: BitwiseOr, - right, - }) if is_negative_of(&left, &right) && !info.nullable(&right)? => { - Transformed::yes(Expr::Literal(ScalarValue::new_negative_one( + Expr::BinaryExpr( + BinaryExpr { + left, + op: BitwiseOr, + right, + }, + _, + ) if is_negative_of(&left, &right) && !info.nullable(&right)? => { + Transformed::yes(Expr::literal(ScalarValue::new_negative_one( &info.get_data_type(&left)?, )?)) } // A | !A -> -1 (if A not nullable) - Expr::BinaryExpr(BinaryExpr { - left, - op: BitwiseOr, - right, - }) if is_negative_of(&right, &left) && !info.nullable(&left)? => { - Transformed::yes(Expr::Literal(ScalarValue::new_negative_one( + Expr::BinaryExpr( + BinaryExpr { + left, + op: BitwiseOr, + right, + }, + _, + ) if is_negative_of(&right, &left) && !info.nullable(&left)? => { + Transformed::yes(Expr::literal(ScalarValue::new_negative_one( &info.get_data_type(&left)?, )?)) } // (..A..) | A --> (..A..) - Expr::BinaryExpr(BinaryExpr { - left, - op: BitwiseOr, - right, - }) if expr_contains(&left, &right, BitwiseOr) => Transformed::yes(*left), + Expr::BinaryExpr( + BinaryExpr { + left, + op: BitwiseOr, + right, + }, + _, + ) if expr_contains(&left, &right, BitwiseOr) => Transformed::yes(*left), // A | (..A..) --> (..A..) - Expr::BinaryExpr(BinaryExpr { - left, - op: BitwiseOr, - right, - }) if expr_contains(&right, &left, BitwiseOr) => Transformed::yes(*right), + Expr::BinaryExpr( + BinaryExpr { + left, + op: BitwiseOr, + right, + }, + _, + ) if expr_contains(&right, &left, BitwiseOr) => Transformed::yes(*right), // A | (A & B) --> A (if B not null) - Expr::BinaryExpr(BinaryExpr { - left, - op: BitwiseOr, - right, - }) if !info.nullable(&right)? && is_op_with(BitwiseAnd, &right, &left) => { + Expr::BinaryExpr( + BinaryExpr { + left, + op: BitwiseOr, + right, + }, + _, + ) if !info.nullable(&right)? && is_op_with(BitwiseAnd, &right, &left) => { Transformed::yes(*left) } // (A & B) | A --> A (if B not null) - Expr::BinaryExpr(BinaryExpr { - left, - op: BitwiseOr, - right, - }) if !info.nullable(&left)? && is_op_with(BitwiseAnd, &left, &right) => { + Expr::BinaryExpr( + BinaryExpr { + left, + op: BitwiseOr, + right, + }, + _, + ) if !info.nullable(&left)? && is_op_with(BitwiseAnd, &left, &right) => { Transformed::yes(*right) } @@ -1228,78 +1443,102 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { // // A ^ null -> null - Expr::BinaryExpr(BinaryExpr { - left: _, - op: BitwiseXor, - right, - }) if is_null(&right) => Transformed::yes(*right), + Expr::BinaryExpr( + BinaryExpr { + left: _, + op: BitwiseXor, + right, + }, + _, + ) if is_null(&right) => Transformed::yes(*right), // null ^ A -> null - Expr::BinaryExpr(BinaryExpr { - left, - op: BitwiseXor, - right: _, - }) if is_null(&left) => Transformed::yes(*left), + Expr::BinaryExpr( + BinaryExpr { + left, + op: BitwiseXor, + right: _, + }, + _, + ) if is_null(&left) => Transformed::yes(*left), // A ^ 0 -> A (if A not nullable) - Expr::BinaryExpr(BinaryExpr { - left, - op: BitwiseXor, - right, - }) if !info.nullable(&left)? && is_zero(&right) => Transformed::yes(*left), + Expr::BinaryExpr( + BinaryExpr { + left, + op: BitwiseXor, + right, + }, + _, + ) if !info.nullable(&left)? && is_zero(&right) => Transformed::yes(*left), // 0 ^ A -> A (if A not nullable) - Expr::BinaryExpr(BinaryExpr { - left, - op: BitwiseXor, - right, - }) if !info.nullable(&right)? && is_zero(&left) => Transformed::yes(*right), + Expr::BinaryExpr( + BinaryExpr { + left, + op: BitwiseXor, + right, + }, + _, + ) if !info.nullable(&right)? && is_zero(&left) => Transformed::yes(*right), // !A ^ A -> -1 (if A not nullable) - Expr::BinaryExpr(BinaryExpr { - left, - op: BitwiseXor, - right, - }) if is_negative_of(&left, &right) && !info.nullable(&right)? => { - Transformed::yes(Expr::Literal(ScalarValue::new_negative_one( + Expr::BinaryExpr( + BinaryExpr { + left, + op: BitwiseXor, + right, + }, + _, + ) if is_negative_of(&left, &right) && !info.nullable(&right)? => { + Transformed::yes(Expr::literal(ScalarValue::new_negative_one( &info.get_data_type(&left)?, )?)) } // A ^ !A -> -1 (if A not nullable) - Expr::BinaryExpr(BinaryExpr { - left, - op: BitwiseXor, - right, - }) if is_negative_of(&right, &left) && !info.nullable(&left)? => { - Transformed::yes(Expr::Literal(ScalarValue::new_negative_one( + Expr::BinaryExpr( + BinaryExpr { + left, + op: BitwiseXor, + right, + }, + _, + ) if is_negative_of(&right, &left) && !info.nullable(&left)? => { + Transformed::yes(Expr::literal(ScalarValue::new_negative_one( &info.get_data_type(&left)?, )?)) } // (..A..) ^ A --> (the expression without A, if number of A is odd, otherwise one A) - Expr::BinaryExpr(BinaryExpr { - left, - op: BitwiseXor, - right, - }) if expr_contains(&left, &right, BitwiseXor) => { + Expr::BinaryExpr( + BinaryExpr { + left, + op: BitwiseXor, + right, + }, + _, + ) if expr_contains(&left, &right, BitwiseXor) => { let expr = delete_xor_in_complex_expr(&left, &right, false); Transformed::yes(if expr == *right { - Expr::Literal(ScalarValue::new_zero(&info.get_data_type(&right)?)?) + Expr::literal(ScalarValue::new_zero(&info.get_data_type(&right)?)?) } else { expr }) } // A ^ (..A..) --> (the expression without A, if number of A is odd, otherwise one A) - Expr::BinaryExpr(BinaryExpr { - left, - op: BitwiseXor, - right, - }) if expr_contains(&right, &left, BitwiseXor) => { + Expr::BinaryExpr( + BinaryExpr { + left, + op: BitwiseXor, + right, + }, + _, + ) if expr_contains(&right, &left, BitwiseXor) => { let expr = delete_xor_in_complex_expr(&right, &left, true); Transformed::yes(if expr == *left { - Expr::Literal(ScalarValue::new_zero(&info.get_data_type(&left)?)?) + Expr::literal(ScalarValue::new_zero(&info.get_data_type(&left)?)?) } else { expr }) @@ -1310,60 +1549,78 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { // // A >> null -> null - Expr::BinaryExpr(BinaryExpr { - left: _, - op: BitwiseShiftRight, - right, - }) if is_null(&right) => Transformed::yes(*right), + Expr::BinaryExpr( + BinaryExpr { + left: _, + op: BitwiseShiftRight, + right, + }, + _, + ) if is_null(&right) => Transformed::yes(*right), // null >> A -> null - Expr::BinaryExpr(BinaryExpr { - left, - op: BitwiseShiftRight, - right: _, - }) if is_null(&left) => Transformed::yes(*left), + Expr::BinaryExpr( + BinaryExpr { + left, + op: BitwiseShiftRight, + right: _, + }, + _, + ) if is_null(&left) => Transformed::yes(*left), // A >> 0 -> A (even if A is null) - Expr::BinaryExpr(BinaryExpr { - left, - op: BitwiseShiftRight, - right, - }) if is_zero(&right) => Transformed::yes(*left), + Expr::BinaryExpr( + BinaryExpr { + left, + op: BitwiseShiftRight, + right, + }, + _, + ) if is_zero(&right) => Transformed::yes(*left), // // Rules for BitwiseShiftRight // // A << null -> null - Expr::BinaryExpr(BinaryExpr { - left: _, - op: BitwiseShiftLeft, - right, - }) if is_null(&right) => Transformed::yes(*right), + Expr::BinaryExpr( + BinaryExpr { + left: _, + op: BitwiseShiftLeft, + right, + }, + _, + ) if is_null(&right) => Transformed::yes(*right), // null << A -> null - Expr::BinaryExpr(BinaryExpr { - left, - op: BitwiseShiftLeft, - right: _, - }) if is_null(&left) => Transformed::yes(*left), + Expr::BinaryExpr( + BinaryExpr { + left, + op: BitwiseShiftLeft, + right: _, + }, + _, + ) if is_null(&left) => Transformed::yes(*left), // A << 0 -> A (even if A is null) - Expr::BinaryExpr(BinaryExpr { - left, - op: BitwiseShiftLeft, - right, - }) if is_zero(&right) => Transformed::yes(*left), + Expr::BinaryExpr( + BinaryExpr { + left, + op: BitwiseShiftLeft, + right, + }, + _, + ) if is_zero(&right) => Transformed::yes(*left), // // Rules for Not // - Expr::Not(inner) => Transformed::yes(negate_clause(*inner)), + Expr::Not(inner, _) => Transformed::yes(negate_clause(*inner)), // // Rules for Negative // - Expr::Negative(inner) => Transformed::yes(distribute_negation(*inner)), + Expr::Negative(inner, _) => Transformed::yes(distribute_negation(*inner)), // // Rules for Case @@ -1380,11 +1637,14 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { // // Note: the rationale for this rewrite is that the expr can then be further // simplified using the existing rules for AND/OR - Expr::Case(Case { - expr: None, - when_then_expr, - else_expr, - }) if !when_then_expr.is_empty() + Expr::Case( + Case { + expr: None, + when_then_expr, + else_expr, + }, + _, + ) if !when_then_expr.is_empty() && when_then_expr.len() < 3 // The rewrite is O(n!) so limit to small number && info.is_boolean_type(&when_then_expr[0].1)? => { @@ -1412,10 +1672,10 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { // Do a first pass at simplification out_expr.rewrite(self)? } - Expr::ScalarFunction(ScalarFunction { func: udf, args }) => { + Expr::ScalarFunction(ScalarFunction { func: udf, args }, _) => { match udf.simplify(args, info)? { ExprSimplifyResult::Original(args) => { - Transformed::no(Expr::ScalarFunction(ScalarFunction { + Transformed::no(Expr::scalar_function(ScalarFunction { func: udf, args, })) @@ -1424,21 +1684,24 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { } } - Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction { - ref func, - .. - }) => match (func.simplify(), expr) { - (Some(simplify_function), Expr::AggregateFunction(af)) => { + Expr::AggregateFunction( + datafusion_expr::expr::AggregateFunction { ref func, .. }, + _, + ) => match (func.simplify(), expr) { + (Some(simplify_function), Expr::AggregateFunction(af, _)) => { Transformed::yes(simplify_function(af, info)?) } (_, expr) => Transformed::no(expr), }, - Expr::WindowFunction(WindowFunction { - fun: WindowFunctionDefinition::WindowUDF(ref udwf), - .. - }) => match (udwf.simplify(), expr) { - (Some(simplify_function), Expr::WindowFunction(wf)) => { + Expr::WindowFunction( + WindowFunction { + fun: WindowFunctionDefinition::WindowUDF(ref udwf), + .. + }, + _, + ) => match (udwf.simplify(), expr) { + (Some(simplify_function), Expr::WindowFunction(wf, _)) => { Transformed::yes(simplify_function(wf, info)?) } (_, expr) => Transformed::no(expr), @@ -1450,7 +1713,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { // a between 3 and 5 --> a >= 3 AND a <=5 // a not between 3 and 5 --> a < 3 OR a > 5 - Expr::Between(between) => Transformed::yes(if between.negated { + Expr::Between(between, _) => Transformed::yes(if between.negated { let l = *between.expr.clone(); let r = *between.expr; or(l.lt(*between.low), r.gt(*between.high)) @@ -1464,14 +1727,17 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { // // Rules for regexes // - Expr::BinaryExpr(BinaryExpr { - left, - op: op @ (RegexMatch | RegexNotMatch | RegexIMatch | RegexNotIMatch), - right, - }) => Transformed::yes(simplify_regex_expr(left, op, right)?), + Expr::BinaryExpr( + BinaryExpr { + left, + op: op @ (RegexMatch | RegexNotMatch | RegexIMatch | RegexNotIMatch), + right, + }, + _, + ) => Transformed::yes(simplify_regex_expr(left, op, right)?), // Rules for Like - Expr::Like(like) => { + Expr::Like(like, _) => { // `\` is implicit escape, see https://github.com/apache/datafusion/issues/13291 let escape_char = like.escape_char.unwrap_or('\\'); match as_string_scalar(&like.pattern) { @@ -1489,8 +1755,10 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { Transformed::yes(if !info.nullable(&like.expr)? { result_for_non_null } else { - Expr::Case(Case { - expr: Some(Box::new(Expr::IsNotNull(like.expr))), + Expr::case(Case { + expr: Some(Box::new(Expr::_is_not_null( + like.expr, + ))), when_then_expr: vec![( Box::new(lit(true)), Box::new(result_for_non_null), @@ -1509,7 +1777,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { .unwrap() .replace_all(pattern_str, "%") .to_string(); - Transformed::yes(Expr::Like(Like { + Transformed::yes(Expr::_like(Like { pattern: Box::new(to_string_scalar( data_type, Some(simplified_pattern), @@ -1523,63 +1791,74 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { { // If the pattern does not contain any wildcards, we can simplify the like expression to an equality expression // TODO: handle escape characters - Transformed::yes(Expr::BinaryExpr(BinaryExpr { + Transformed::yes(Expr::binary_expr(BinaryExpr { left: like.expr.clone(), op: if like.negated { NotEq } else { Eq }, right: like.pattern.clone(), })) } - Some(_pattern_str) => Transformed::no(Expr::Like(like)), + Some(_pattern_str) => Transformed::no(Expr::_like(like)), } } - None => Transformed::no(Expr::Like(like)), + None => Transformed::no(Expr::_like(like)), } } // a is not null/unknown --> true (if a is not nullable) - Expr::IsNotNull(expr) | Expr::IsNotUnknown(expr) + Expr::IsNotNull(expr, _) | Expr::IsNotUnknown(expr, _) if !info.nullable(&expr)? => { Transformed::yes(lit(true)) } // a is null/unknown --> false (if a is not nullable) - Expr::IsNull(expr) | Expr::IsUnknown(expr) if !info.nullable(&expr)? => { + Expr::IsNull(expr, _) | Expr::IsUnknown(expr, _) + if !info.nullable(&expr)? => + { Transformed::yes(lit(false)) } // expr IN () --> false // expr NOT IN () --> true - Expr::InList(InList { - expr, - list, - negated, - }) if list.is_empty() && *expr != Expr::Literal(ScalarValue::Null) => { + Expr::InList( + InList { + expr, + list, + negated, + }, + _, + ) if list.is_empty() && *expr != Expr::literal(ScalarValue::Null) => { Transformed::yes(lit(negated)) } // null in (x, y, z) --> null // null not in (x, y, z) --> null - Expr::InList(InList { - expr, - list: _, - negated: _, - }) if is_null(expr.as_ref()) => Transformed::yes(lit_bool_null()), + Expr::InList( + InList { + expr, + list: _, + negated: _, + }, + _, + ) if is_null(expr.as_ref()) => Transformed::yes(lit_bool_null()), // expr IN ((subquery)) -> expr IN (subquery), see ##5529 - Expr::InList(InList { - expr, - mut list, - negated, - }) if list.len() == 1 + Expr::InList( + InList { + expr, + mut list, + negated, + }, + _, + ) if list.len() == 1 && matches!(list.first(), Some(Expr::ScalarSubquery { .. })) => { - let Expr::ScalarSubquery(subquery) = list.remove(0) else { + let Expr::ScalarSubquery(subquery, _) = list.remove(0) else { unreachable!() }; - Transformed::yes(Expr::InSubquery(InSubquery::new( + Transformed::yes(Expr::in_subquery(InSubquery::new( expr, subquery, negated, ))) } @@ -1587,11 +1866,14 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { // Combine multiple OR expressions into a single IN list expression if possible // // i.e. `a = 1 OR a = 2 OR a = 3` -> `a IN (1, 2, 3)` - Expr::BinaryExpr(BinaryExpr { - left, - op: Or, - right, - }) if are_inlist_and_eq(left.as_ref(), right.as_ref()) => { + Expr::BinaryExpr( + BinaryExpr { + left, + op: Or, + right, + }, + _, + ) if are_inlist_and_eq(left.as_ref(), right.as_ref()) => { let lhs = to_inlist(*left).unwrap(); let rhs = to_inlist(*right).unwrap(); let mut seen: HashSet = HashSet::new(); @@ -1608,7 +1890,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { negated: false, }; - Transformed::yes(Expr::InList(merged_inlist)) + Transformed::yes(Expr::_in_list(merged_inlist)) } // Simplify expressions that is guaranteed to be true or false to a literal boolean expression @@ -1627,11 +1909,14 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { // 6. `a in (1,2,3,4) AND a not in (1,2,3,4,5) -> a in (), which is false` // 7. `a not in (1,2,3,4) AND a in (1,2,3,4,5) -> a = 5` // 8. `a in (1,2,3,4) AND a not in (5,6,7,8) -> a in (1,2,3,4)` - Expr::BinaryExpr(BinaryExpr { - left, - op: And, - right, - }) if are_inlist_and_eq_and_match_neg( + Expr::BinaryExpr( + BinaryExpr { + left, + op: And, + right, + }, + _, + ) if are_inlist_and_eq_and_match_neg( left.as_ref(), right.as_ref(), false, @@ -1639,7 +1924,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { ) => { match (*left, *right) { - (Expr::InList(l1), Expr::InList(l2)) => { + (Expr::InList(l1, _), Expr::InList(l2, _)) => { return inlist_intersection(l1, &l2, false).map(Transformed::yes); } // Matched previously once @@ -1647,11 +1932,14 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { } } - Expr::BinaryExpr(BinaryExpr { - left, - op: And, - right, - }) if are_inlist_and_eq_and_match_neg( + Expr::BinaryExpr( + BinaryExpr { + left, + op: And, + right, + }, + _, + ) if are_inlist_and_eq_and_match_neg( left.as_ref(), right.as_ref(), true, @@ -1659,7 +1947,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { ) => { match (*left, *right) { - (Expr::InList(l1), Expr::InList(l2)) => { + (Expr::InList(l1, _), Expr::InList(l2, _)) => { return inlist_union(l1, l2, true).map(Transformed::yes); } // Matched previously once @@ -1667,11 +1955,14 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { } } - Expr::BinaryExpr(BinaryExpr { - left, - op: And, - right, - }) if are_inlist_and_eq_and_match_neg( + Expr::BinaryExpr( + BinaryExpr { + left, + op: And, + right, + }, + _, + ) if are_inlist_and_eq_and_match_neg( left.as_ref(), right.as_ref(), false, @@ -1679,7 +1970,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { ) => { match (*left, *right) { - (Expr::InList(l1), Expr::InList(l2)) => { + (Expr::InList(l1, _), Expr::InList(l2, _)) => { return inlist_except(l1, &l2).map(Transformed::yes); } // Matched previously once @@ -1687,11 +1978,14 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { } } - Expr::BinaryExpr(BinaryExpr { - left, - op: And, - right, - }) if are_inlist_and_eq_and_match_neg( + Expr::BinaryExpr( + BinaryExpr { + left, + op: And, + right, + }, + _, + ) if are_inlist_and_eq_and_match_neg( left.as_ref(), right.as_ref(), true, @@ -1699,7 +1993,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { ) => { match (*left, *right) { - (Expr::InList(l1), Expr::InList(l2)) => { + (Expr::InList(l1, _), Expr::InList(l2, _)) => { return inlist_except(l2, &l1).map(Transformed::yes); } // Matched previously once @@ -1707,11 +2001,14 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { } } - Expr::BinaryExpr(BinaryExpr { - left, - op: Or, - right, - }) if are_inlist_and_eq_and_match_neg( + Expr::BinaryExpr( + BinaryExpr { + left, + op: Or, + right, + }, + _, + ) if are_inlist_and_eq_and_match_neg( left.as_ref(), right.as_ref(), true, @@ -1719,7 +2016,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { ) => { match (*left, *right) { - (Expr::InList(l1), Expr::InList(l2)) => { + (Expr::InList(l1, _), Expr::InList(l2, _)) => { return inlist_intersection(l1, &l2, true).map(Transformed::yes); } // Matched previously once @@ -1735,18 +2032,18 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { fn as_string_scalar(expr: &Expr) -> Option<(DataType, &Option)> { match expr { - Expr::Literal(ScalarValue::Utf8(s)) => Some((DataType::Utf8, s)), - Expr::Literal(ScalarValue::LargeUtf8(s)) => Some((DataType::LargeUtf8, s)), - Expr::Literal(ScalarValue::Utf8View(s)) => Some((DataType::Utf8View, s)), + Expr::Literal(ScalarValue::Utf8(s), _) => Some((DataType::Utf8, s)), + Expr::Literal(ScalarValue::LargeUtf8(s), _) => Some((DataType::LargeUtf8, s)), + Expr::Literal(ScalarValue::Utf8View(s), _) => Some((DataType::Utf8View, s)), _ => None, } } fn to_string_scalar(data_type: DataType, value: Option) -> Expr { match data_type { - DataType::Utf8 => Expr::Literal(ScalarValue::Utf8(value)), - DataType::LargeUtf8 => Expr::Literal(ScalarValue::LargeUtf8(value)), - DataType::Utf8View => Expr::Literal(ScalarValue::Utf8View(value)), + DataType::Utf8 => Expr::literal(ScalarValue::Utf8(value)), + DataType::LargeUtf8 => Expr::literal(ScalarValue::LargeUtf8(value)), + DataType::Utf8View => Expr::literal(ScalarValue::Utf8View(value)), _ => unreachable!(), } } @@ -1764,7 +2061,7 @@ fn are_inlist_and_eq_and_match_neg( is_right_neg: bool, ) -> bool { match (left, right) { - (Expr::InList(l), Expr::InList(r)) => { + (Expr::InList(l, _), Expr::InList(r, _)) => { l.expr == r.expr && l.negated == is_left_neg && r.negated == is_right_neg } _ => false, @@ -1776,8 +2073,8 @@ fn are_inlist_and_eq(left: &Expr, right: &Expr) -> bool { let left = as_inlist(left); let right = as_inlist(right); if let (Some(lhs), Some(rhs)) = (left, right) { - matches!(lhs.expr.as_ref(), Expr::Column(_)) - && matches!(rhs.expr.as_ref(), Expr::Column(_)) + matches!(lhs.expr.as_ref(), Expr::Column(_, _)) + && matches!(rhs.expr.as_ref(), Expr::Column(_, _)) && lhs.expr == rhs.expr && !lhs.negated && !rhs.negated @@ -1789,15 +2086,15 @@ fn are_inlist_and_eq(left: &Expr, right: &Expr) -> bool { /// Try to convert an expression to an in-list expression fn as_inlist(expr: &Expr) -> Option> { match expr { - Expr::InList(inlist) => Some(Cow::Borrowed(inlist)), - Expr::BinaryExpr(BinaryExpr { left, op, right }) if *op == Operator::Eq => { + Expr::InList(inlist, _) => Some(Cow::Borrowed(inlist)), + Expr::BinaryExpr(BinaryExpr { left, op, right }, _) if *op == Operator::Eq => { match (left.as_ref(), right.as_ref()) { - (Expr::Column(_), Expr::Literal(_)) => Some(Cow::Owned(InList { + (Expr::Column(_, _), Expr::Literal(_, _)) => Some(Cow::Owned(InList { expr: left.clone(), list: vec![*right.clone()], negated: false, })), - (Expr::Literal(_), Expr::Column(_)) => Some(Cow::Owned(InList { + (Expr::Literal(_, _), Expr::Column(_, _)) => Some(Cow::Owned(InList { expr: right.clone(), list: vec![*left.clone()], negated: false, @@ -1811,18 +2108,21 @@ fn as_inlist(expr: &Expr) -> Option> { fn to_inlist(expr: Expr) -> Option { match expr { - Expr::InList(inlist) => Some(inlist), - Expr::BinaryExpr(BinaryExpr { - left, - op: Operator::Eq, - right, - }) => match (left.as_ref(), right.as_ref()) { - (Expr::Column(_), Expr::Literal(_)) => Some(InList { + Expr::InList(inlist, _) => Some(inlist), + Expr::BinaryExpr( + BinaryExpr { + left, + op: Operator::Eq, + right, + }, + _, + ) => match (left.as_ref(), right.as_ref()) { + (Expr::Column(_, _), Expr::Literal(_, _)) => Some(InList { expr: left, list: vec![*right], negated: false, }), - (Expr::Literal(_), Expr::Column(_)) => Some(InList { + (Expr::Literal(_, _), Expr::Column(_, _)) => Some(InList { expr: right, list: vec![*left], negated: false, @@ -1848,7 +2148,7 @@ fn inlist_union(mut l1: InList, l2: InList, negated: bool) -> Result { l1.list.extend(keep_l2); l1.negated = negated; - Ok(Expr::InList(l1)) + Ok(Expr::_in_list(l1)) } /// Return the intersection of two inlist expressions @@ -1864,7 +2164,7 @@ fn inlist_intersection(mut l1: InList, l2: &InList, negated: bool) -> Result Result { if l1.list.is_empty() { return Ok(lit(false)); } - Ok(Expr::InList(l1)) + Ok(Expr::_in_list(l1)) } #[cfg(test)] @@ -2134,7 +2434,7 @@ mod tests { #[test] fn test_simplify_multiply_by_null() { - let null = Expr::Literal(ScalarValue::Null); + let null = Expr::literal(ScalarValue::Null); // A * null --> null { let expr = col("c2") * null.clone(); @@ -3025,7 +3325,7 @@ mod tests { } fn regex_match(left: Expr, right: Expr) -> Expr { - Expr::BinaryExpr(BinaryExpr { + Expr::binary_expr(BinaryExpr { left: Box::new(left), op: Operator::RegexMatch, right: Box::new(right), @@ -3033,7 +3333,7 @@ mod tests { } fn regex_not_match(left: Expr, right: Expr) -> Expr { - Expr::BinaryExpr(BinaryExpr { + Expr::binary_expr(BinaryExpr { left: Box::new(left), op: Operator::RegexNotMatch, right: Box::new(right), @@ -3041,7 +3341,7 @@ mod tests { } fn regex_imatch(left: Expr, right: Expr) -> Expr { - Expr::BinaryExpr(BinaryExpr { + Expr::binary_expr(BinaryExpr { left: Box::new(left), op: Operator::RegexIMatch, right: Box::new(right), @@ -3049,7 +3349,7 @@ mod tests { } fn regex_not_imatch(left: Expr, right: Expr) -> Expr { - Expr::BinaryExpr(BinaryExpr { + Expr::binary_expr(BinaryExpr { left: Box::new(left), op: Operator::RegexNotIMatch, right: Box::new(right), @@ -3151,13 +3451,13 @@ mod tests { #[test] fn simplify_expr_is_not_null() { assert_eq!( - simplify(Expr::IsNotNull(Box::new(col("c1")))), - Expr::IsNotNull(Box::new(col("c1"))) + simplify(Expr::_is_not_null(Box::new(col("c1")))), + Expr::_is_not_null(Box::new(col("c1"))) ); // 'c1_non_null IS NOT NULL' is always true assert_eq!( - simplify(Expr::IsNotNull(Box::new(col("c1_non_null")))), + simplify(Expr::_is_not_null(Box::new(col("c1_non_null")))), lit(true) ); } @@ -3165,13 +3465,13 @@ mod tests { #[test] fn simplify_expr_is_null() { assert_eq!( - simplify(Expr::IsNull(Box::new(col("c1")))), - Expr::IsNull(Box::new(col("c1"))) + simplify(Expr::_is_null(Box::new(col("c1")))), + Expr::_is_null(Box::new(col("c1"))) ); // 'c1_non_null IS NULL' is always false assert_eq!( - simplify(Expr::IsNull(Box::new(col("c1_non_null")))), + simplify(Expr::_is_null(Box::new(col("c1_non_null")))), lit(false) ); } @@ -3269,7 +3569,7 @@ mod tests { // --> // false assert_eq!( - simplify(Expr::Case(Case::new( + simplify(Expr::case(Case::new( None, vec![( Box::new(col("c2").not_eq(lit(false))), @@ -3289,7 +3589,7 @@ mod tests { // Need to call simplify 2x due to // https://github.com/apache/datafusion/issues/1160 assert_eq!( - simplify(simplify(Expr::Case(Case::new( + simplify(simplify(Expr::case(Case::new( None, vec![( Box::new(col("c2").not_eq(lit(false))), @@ -3307,7 +3607,7 @@ mod tests { // Need to call simplify 2x due to // https://github.com/apache/datafusion/issues/1160 assert_eq!( - simplify(simplify(Expr::Case(Case::new( + simplify(simplify(Expr::case(Case::new( None, vec![(Box::new(col("c2").is_null()), Box::new(lit(true)),)], Some(Box::new(col("c2"))), @@ -3325,7 +3625,7 @@ mod tests { // Need to call simplify 2x due to // https://github.com/apache/datafusion/issues/1160 assert_eq!( - simplify(simplify(Expr::Case(Case::new( + simplify(simplify(Expr::case(Case::new( None, vec![ (Box::new(col("c1")), Box::new(lit(true)),), @@ -3344,7 +3644,7 @@ mod tests { // Need to call simplify 2x due to // https://github.com/apache/datafusion/issues/1160 assert_eq!( - simplify(simplify(Expr::Case(Case::new( + simplify(simplify(Expr::case(Case::new( None, vec![ (Box::new(col("c1")), Box::new(lit(true)),), @@ -3963,7 +4263,7 @@ mod tests { fn test_simplify_udaf() { let udaf = AggregateUDF::new_from_impl(SimplifyMockUdaf::new_with_simplify()); let aggregate_function_expr = - Expr::AggregateFunction(expr::AggregateFunction::new_udf( + Expr::aggregate_function(expr::AggregateFunction::new_udf( udaf.into(), vec![], false, @@ -3977,7 +4277,7 @@ mod tests { let udaf = AggregateUDF::new_from_impl(SimplifyMockUdaf::new_without_simplify()); let aggregate_function_expr = - Expr::AggregateFunction(expr::AggregateFunction::new_udf( + Expr::aggregate_function(expr::AggregateFunction::new_udf( udaf.into(), vec![], false, @@ -4058,7 +4358,7 @@ mod tests { WindowUDF::new_from_impl(SimplifyMockUdwf::new_with_simplify()).into(), ); let window_function_expr = - Expr::WindowFunction(WindowFunction::new(udwf, vec![])); + Expr::window_function(WindowFunction::new(udwf, vec![])); let expected = col("result_column"); assert_eq!(simplify(window_function_expr), expected); @@ -4067,7 +4367,7 @@ mod tests { WindowUDF::new_from_impl(SimplifyMockUdwf::new_without_simplify()).into(), ); let window_function_expr = - Expr::WindowFunction(WindowFunction::new(udwf, vec![])); + Expr::window_function(WindowFunction::new(udwf, vec![])); let expected = window_function_expr.clone(); assert_eq!(simplify(window_function_expr), expected); @@ -4156,7 +4456,7 @@ mod tests { #[test] fn test_optimize_volatile_conditions() { let fun = Arc::new(ScalarUDF::new_from_impl(VolatileUdf::new())); - let rand = Expr::ScalarFunction(ScalarFunction::new_udf(fun, vec![])); + let rand = Expr::scalar_function(ScalarFunction::new_udf(fun, vec![])); { let expr = rand .clone() @@ -4191,7 +4491,7 @@ mod tests { } fn if_not_null(expr: Expr, then: bool) -> Expr { - Expr::Case(Case { + Expr::case(Case { expr: Some(expr.is_not_null().into()), when_then_expr: vec![(lit(true).into(), lit(then).into())], else_expr: None, diff --git a/datafusion/optimizer/src/simplify_expressions/guarantees.rs b/datafusion/optimizer/src/simplify_expressions/guarantees.rs index afcbe528083b8..0887b82d8b569 100644 --- a/datafusion/optimizer/src/simplify_expressions/guarantees.rs +++ b/datafusion/optimizer/src/simplify_expressions/guarantees.rs @@ -19,12 +19,13 @@ //! //! [`ExprSimplifier::with_guarantees()`]: crate::simplify_expressions::expr_simplifier::ExprSimplifier::with_guarantees -use std::{borrow::Cow, collections::HashMap}; - use datafusion_common::tree_node::{Transformed, TreeNodeRewriter}; use datafusion_common::{DataFusionError, Result}; use datafusion_expr::interval_arithmetic::{Interval, NullableInterval}; +use datafusion_expr::logical_plan::tree_node::LogicalPlanPattern; use datafusion_expr::{expr::InList, lit, Between, BinaryExpr, Expr}; +use enumset::enum_set; +use std::{borrow::Cow, collections::HashMap}; /// Rewrite expressions to incorporate guarantees. /// @@ -41,6 +42,7 @@ use datafusion_expr::{expr::InList, lit, Between, BinaryExpr, Expr}; /// [`ExprSimplifier::with_guarantees()`]: crate::simplify_expressions::expr_simplifier::ExprSimplifier::with_guarantees pub struct GuaranteeRewriter<'a> { guarantees: HashMap<&'a Expr, &'a NullableInterval>, + skip: bool, } impl<'a> GuaranteeRewriter<'a> { @@ -53,6 +55,7 @@ impl<'a> GuaranteeRewriter<'a> { // issue is fixed. #[allow(clippy::map_identity)] guarantees: guarantees.into_iter().map(|(k, v)| (k, v)).collect(), + skip: false, } } } @@ -60,31 +63,55 @@ impl<'a> GuaranteeRewriter<'a> { impl<'a> TreeNodeRewriter for GuaranteeRewriter<'a> { type Node = Expr; + fn f_down(&mut self, node: Self::Node) -> Result> { + if !node.stats().contains_any_patterns(enum_set!( + LogicalPlanPattern::ExprIsNull + | LogicalPlanPattern::ExprIsNotNull + | LogicalPlanPattern::ExprBetween + | LogicalPlanPattern::ExprBinaryExpr + | LogicalPlanPattern::ExprColumn + | LogicalPlanPattern::ExprInList + )) { + self.skip = true; + return Ok(Transformed::jump(node)); + } + + Ok(Transformed::no(node)) + } + fn f_up(&mut self, expr: Expr) -> Result> { + if self.skip { + self.skip = false; + return Ok(Transformed::no(expr)); + } + if self.guarantees.is_empty() { return Ok(Transformed::no(expr)); } match &expr { - Expr::IsNull(inner) => match self.guarantees.get(inner.as_ref()) { + Expr::IsNull(inner, _) => match self.guarantees.get(inner.as_ref()) { Some(NullableInterval::Null { .. }) => Ok(Transformed::yes(lit(true))), Some(NullableInterval::NotNull { .. }) => { Ok(Transformed::yes(lit(false))) } _ => Ok(Transformed::no(expr)), }, - Expr::IsNotNull(inner) => match self.guarantees.get(inner.as_ref()) { + Expr::IsNotNull(inner, _) => match self.guarantees.get(inner.as_ref()) { Some(NullableInterval::Null { .. }) => Ok(Transformed::yes(lit(false))), Some(NullableInterval::NotNull { .. }) => Ok(Transformed::yes(lit(true))), _ => Ok(Transformed::no(expr)), }, - Expr::Between(Between { - expr: inner, - negated, - low, - high, - }) => { - if let (Some(interval), Expr::Literal(low), Expr::Literal(high)) = ( + Expr::Between( + Between { + expr: inner, + negated, + low, + high, + }, + _, + ) => { + if let (Some(interval), Expr::Literal(low, _), Expr::Literal(high, _)) = ( self.guarantees.get(inner.as_ref()), low.as_ref(), high.as_ref(), @@ -107,7 +134,7 @@ impl<'a> TreeNodeRewriter for GuaranteeRewriter<'a> { } } - Expr::BinaryExpr(BinaryExpr { left, op, right }) => { + Expr::BinaryExpr(BinaryExpr { left, op, right }, _) => { // The left or right side of expression might either have a guarantee // or be a literal. Either way, we can resolve them to a NullableInterval. let left_interval = self @@ -115,7 +142,7 @@ impl<'a> TreeNodeRewriter for GuaranteeRewriter<'a> { .get(left.as_ref()) .map(|interval| Cow::Borrowed(*interval)) .or_else(|| { - if let Expr::Literal(value) = left.as_ref() { + if let Expr::Literal(value, _) = left.as_ref() { Some(Cow::Owned(value.clone().into())) } else { None @@ -126,7 +153,7 @@ impl<'a> TreeNodeRewriter for GuaranteeRewriter<'a> { .get(right.as_ref()) .map(|interval| Cow::Borrowed(*interval)) .or_else(|| { - if let Expr::Literal(value) = right.as_ref() { + if let Expr::Literal(value, _) = right.as_ref() { Some(Cow::Owned(value.clone().into())) } else { None @@ -150,7 +177,7 @@ impl<'a> TreeNodeRewriter for GuaranteeRewriter<'a> { } // Columns (if interval is collapsed to a single value) - Expr::Column(_) => { + Expr::Column(_, _) => { if let Some(interval) = self.guarantees.get(&expr) { Ok(Transformed::yes(interval.single_value().map_or(expr, lit))) } else { @@ -158,17 +185,20 @@ impl<'a> TreeNodeRewriter for GuaranteeRewriter<'a> { } } - Expr::InList(InList { - expr: inner, - list, - negated, - }) => { + Expr::InList( + InList { + expr: inner, + list, + negated, + }, + _, + ) => { if let Some(interval) = self.guarantees.get(inner.as_ref()) { // Can remove items from the list that don't match the guarantee let new_list: Vec = list .iter() .filter_map(|expr| { - if let Expr::Literal(item) = expr { + if let Expr::Literal(item, _) = expr { match interval .contains(NullableInterval::from(item.clone())) { @@ -184,7 +214,7 @@ impl<'a> TreeNodeRewriter for GuaranteeRewriter<'a> { }) .collect::>()?; - Ok(Transformed::yes(Expr::InList(InList { + Ok(Transformed::yes(Expr::_in_list(InList { expr: inner.clone(), list: new_list, negated: *negated, @@ -301,7 +331,7 @@ mod tests { true, ), ( - Expr::BinaryExpr(BinaryExpr { + Expr::binary_expr(BinaryExpr { left: Box::new(col("x")), op: Operator::IsDistinctFrom, right: Box::new(lit(ScalarValue::Null)), @@ -309,7 +339,7 @@ mod tests { true, ), ( - Expr::BinaryExpr(BinaryExpr { + Expr::binary_expr(BinaryExpr { left: Box::new(col("x")), op: Operator::IsDistinctFrom, right: Box::new(lit(ScalarValue::Date32(Some(17000)))), @@ -360,7 +390,7 @@ mod tests { // (original_expr, expected_simplification) let simplified_cases = &[ ( - Expr::BinaryExpr(BinaryExpr { + Expr::binary_expr(BinaryExpr { left: Box::new(col("x")), op: Operator::IsDistinctFrom, right: Box::new(lit("z")), @@ -368,7 +398,7 @@ mod tests { true, ), ( - Expr::BinaryExpr(BinaryExpr { + Expr::binary_expr(BinaryExpr { left: Box::new(col("x")), op: Operator::IsNotDistinctFrom, right: Box::new(lit("z")), @@ -388,7 +418,7 @@ mod tests { col("x").not_eq(lit("a")), col("x").between(lit("a"), lit("z")), col("x").not_between(lit("a"), lit("z")), - Expr::BinaryExpr(BinaryExpr { + Expr::binary_expr(BinaryExpr { left: Box::new(col("x")), op: Operator::IsDistinctFrom, right: Box::new(lit(ScalarValue::Null)), @@ -417,7 +447,7 @@ mod tests { let mut rewriter = GuaranteeRewriter::new(guarantees.iter()); let output = col("x").rewrite(&mut rewriter).data().unwrap(); - assert_eq!(output, Expr::Literal(scalar.clone())); + assert_eq!(output, Expr::literal(scalar.clone())); } } @@ -467,7 +497,7 @@ mod tests { .collect(); assert_eq!( output, - Expr::InList(InList { + Expr::_in_list(InList { expr: Box::new(col(*column_name)), list: expected_list, negated: *negated, diff --git a/datafusion/optimizer/src/simplify_expressions/inlist_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/inlist_simplifier.rs index c8638eb723955..1fa04040b638d 100644 --- a/datafusion/optimizer/src/simplify_expressions/inlist_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/inlist_simplifier.rs @@ -22,27 +22,50 @@ use super::THRESHOLD_INLINE_INLIST; use datafusion_common::tree_node::{Transformed, TreeNodeRewriter}; use datafusion_common::Result; use datafusion_expr::expr::InList; +use datafusion_expr::logical_plan::tree_node::LogicalPlanPattern; use datafusion_expr::Expr; -pub(super) struct ShortenInListSimplifier {} +pub(super) struct ShortenInListSimplifier { + skip: bool, +} impl ShortenInListSimplifier { pub(super) fn new() -> Self { - Self {} + Self { skip: false } } } impl TreeNodeRewriter for ShortenInListSimplifier { type Node = Expr; + fn f_down(&mut self, node: Self::Node) -> Result> { + if !node + .stats() + .contains_pattern(LogicalPlanPattern::ExprInList) + { + self.skip = true; + return Ok(Transformed::jump(node)); + } + + Ok(Transformed::no(node)) + } + fn f_up(&mut self, expr: Expr) -> Result> { + if self.skip { + self.skip = false; + return Ok(Transformed::no(expr)); + } + // if expr is a single column reference: // expr IN (A, B, ...) --> (expr = A) OR (expr = B) OR (expr = C) - if let Expr::InList(InList { - expr, - list, - negated, - }) = expr.clone() + if let Expr::InList( + InList { + expr, + list, + negated, + }, + _, + ) = expr.clone() { if !list.is_empty() && ( diff --git a/datafusion/optimizer/src/simplify_expressions/regex.rs b/datafusion/optimizer/src/simplify_expressions/regex.rs index 6c99f18ab0f64..a1dd68d168588 100644 --- a/datafusion/optimizer/src/simplify_expressions/regex.rs +++ b/datafusion/optimizer/src/simplify_expressions/regex.rs @@ -42,7 +42,7 @@ pub fn simplify_regex_expr( ) -> Result { let mode = OperatorMode::new(&op); - if let Expr::Literal(ScalarValue::Utf8(Some(pattern))) = right.as_ref() { + if let Expr::Literal(ScalarValue::Utf8(Some(pattern)), _) = right.as_ref() { match regex_syntax::Parser::new().parse(pattern) { Ok(hir) => { let kind = hir.kind(); @@ -67,7 +67,7 @@ pub fn simplify_regex_expr( } // Leave untouched if optimization didn't work - Ok(Expr::BinaryExpr(BinaryExpr { left, op, right })) + Ok(Expr::binary_expr(BinaryExpr { left, op, right })) } #[derive(Debug)] @@ -100,12 +100,12 @@ impl OperatorMode { let like = Like { negated: self.not, expr, - pattern: Box::new(Expr::Literal(ScalarValue::from(pattern))), + pattern: Box::new(Expr::literal(ScalarValue::from(pattern))), escape_char: None, case_insensitive: self.i, }; - Expr::Like(like) + Expr::_like(like) } /// Creates an [`Expr::BinaryExpr`] of "`left` = `right`" or "`left` != `right`". @@ -115,7 +115,7 @@ impl OperatorMode { } else { Operator::Eq }; - Expr::BinaryExpr(BinaryExpr { left, op, right }) + Expr::binary_expr(BinaryExpr { left, op, right }) } } diff --git a/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs b/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs index 200f1f159d813..0a05e8ce9ff97 100644 --- a/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs +++ b/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs @@ -17,14 +17,16 @@ //! Simplify expressions optimizer rule and implementation -use std::sync::Arc; - use datafusion_common::tree_node::Transformed; use datafusion_common::{DFSchema, DFSchemaRef, DataFusionError, Result}; use datafusion_expr::execution_props::ExecutionProps; +use datafusion_expr::logical_plan::tree_node::LogicalPlanPattern; use datafusion_expr::logical_plan::LogicalPlan; use datafusion_expr::simplify::SimplifyContext; use datafusion_expr::utils::merge_schema; +use enumset::enum_set; +use std::cell::Cell; +use std::sync::Arc; use crate::optimizer::ApplyOrder; use crate::utils::NamePreserver; @@ -54,7 +56,7 @@ impl OptimizerRule for SimplifyExpressions { } fn apply_order(&self) -> Option { - Some(ApplyOrder::BottomUp) + None } fn supports_rewrite(&self) -> bool { @@ -79,58 +81,89 @@ impl SimplifyExpressions { plan: LogicalPlan, execution_props: &ExecutionProps, ) -> Result> { - let schema = if !plan.inputs().is_empty() { - DFSchemaRef::new(merge_schema(&plan.inputs())) - } else if let LogicalPlan::TableScan(scan) = &plan { - // When predicates are pushed into a table scan, there is no input - // schema to resolve predicates against, so it must be handled specially - // - // Note that this is not `plan.schema()` which is the *output* - // schema, and reflects any pushed down projection. The output schema - // will not contain columns that *only* appear in pushed down predicates - // (and no where else) in the plan. - // - // Thus, use the full schema of the inner provider without any - // projection applied for simplification - Arc::new(DFSchema::try_from_qualified_schema( - scan.table_name.clone(), - &scan.source.schema(), - )?) - } else { - Arc::new(DFSchema::empty()) - }; - - let info = SimplifyContext::new(execution_props).with_schema(schema); - - // Inputs have already been rewritten (due to bottom-up traversal handled by Optimizer) - // Just need to rewrite our own expressions - - let simplifier = ExprSimplifier::new(info); - - // The left and right expressions in a Join on clause are not - // commutative, for reasons that are not entirely clear. Thus, do not - // reorder expressions in Join while simplifying. - // - // This is likely related to the fact that order of the columns must - // match the order of the children. see - // https://github.com/apache/datafusion/pull/8780 for more details - let simplifier = if let LogicalPlan::Join(_) = plan { - simplifier.with_canonicalize(false) - } else { - simplifier - }; - - // Preserve expression names to avoid changing the schema of the plan. - let name_preserver = NamePreserver::new(&plan); - plan.map_expressions(|e| { - let original_name = name_preserver.save(&e); - let new_e = simplifier - .simplify(e) - .map(|expr| original_name.restore(expr))?; - // TODO it would be nice to have a way to know if the expression was simplified - // or not. For now conservatively return Transformed::yes - Ok(Transformed::yes(new_e)) - }) + let skip = Cell::new(false); + plan.transform_down_up_with_subqueries( + |plan| { + if !plan.stats().contains_any_patterns(enum_set!( + LogicalPlanPattern::ExprBinaryExpr + | LogicalPlanPattern::ExprNot + | LogicalPlanPattern::ExprNegative + | LogicalPlanPattern::ExprCase + | LogicalPlanPattern::ExprScalarFunction + | LogicalPlanPattern::ExprAggregateFunction + | LogicalPlanPattern::ExprWindowFunction + | LogicalPlanPattern::ExprBetween + | LogicalPlanPattern::ExprIsNotNull + | LogicalPlanPattern::ExprIsNull + | LogicalPlanPattern::ExprInList + | LogicalPlanPattern::ExprLike + )) { + skip.set(true); + return Ok(Transformed::jump(plan)); + } + + Ok(Transformed::no(plan)) + }, + |plan| { + if skip.get() { + skip.set(false); + return Ok(Transformed::no(plan)); + } + + let schema = if !plan.inputs().is_empty() { + DFSchemaRef::new(merge_schema(&plan.inputs())) + } else if let LogicalPlan::TableScan(scan, _) = &plan { + // When predicates are pushed into a table scan, there is no input + // schema to resolve predicates against, so it must be handled specially + // + // Note that this is not `plan.schema()` which is the *output* + // schema, and reflects any pushed down projection. The output schema + // will not contain columns that *only* appear in pushed down predicates + // (and no where else) in the plan. + // + // Thus, use the full schema of the inner provider without any + // projection applied for simplification + Arc::new(DFSchema::try_from_qualified_schema( + scan.table_name.clone(), + &scan.source.schema(), + )?) + } else { + Arc::new(DFSchema::empty()) + }; + + let info = SimplifyContext::new(execution_props).with_schema(schema); + + // Inputs have already been rewritten (due to bottom-up traversal handled by Optimizer) + // Just need to rewrite our own expressions + + let simplifier = ExprSimplifier::new(info); + + // The left and right expressions in a Join on clause are not + // commutative, for reasons that are not entirely clear. Thus, do not + // reorder expressions in Join while simplifying. + // + // This is likely related to the fact that order of the columns must + // match the order of the children. see + // https://github.com/apache/datafusion/pull/8780 for more details + let simplifier = if let LogicalPlan::Join(_, _) = plan { + simplifier.with_canonicalize(false) + } else { + simplifier + }; + + // Preserve expression names to avoid changing the schema of the plan. + let name_preserver = NamePreserver::new(&plan); + plan.map_expressions(|e| { + let original_name = name_preserver.save(&e); + let new_e = simplifier + .simplify(e) + .map(|expr| original_name.restore(expr))?; + // TODO it would be nice to have a way to know if the expression was simplified + // or not. For now conservatively return Transformed::yes + Ok(Transformed::yes(new_e)) + }) + }, + ) } } @@ -406,12 +439,12 @@ mod tests { #[test] fn test_simplify_optimized_plan_support_values() -> Result<()> { - let expr1 = Expr::BinaryExpr(BinaryExpr::new( + let expr1 = Expr::binary_expr(BinaryExpr::new( Box::new(lit(1)), Operator::Plus, Box::new(lit(2)), )); - let expr2 = Expr::BinaryExpr(BinaryExpr::new( + let expr2 = Expr::binary_expr(BinaryExpr::new( Box::new(lit(2)), Operator::Minus, Box::new(lit(1)), @@ -439,7 +472,7 @@ mod tests { #[test] fn cast_expr() -> Result<()> { let table_scan = test_table_scan(); - let proj = vec![Expr::Cast(Cast::new(Box::new(lit("0")), DataType::Int32))]; + let proj = vec![Expr::cast(Cast::new(Box::new(lit("0")), DataType::Int32))]; let plan = LogicalPlanBuilder::from(table_scan) .project(proj)? .build()?; diff --git a/datafusion/optimizer/src/simplify_expressions/utils.rs b/datafusion/optimizer/src/simplify_expressions/utils.rs index c30c3631c193a..bfc0c1af602fc 100644 --- a/datafusion/optimizer/src/simplify_expressions/utils.rs +++ b/datafusion/optimizer/src/simplify_expressions/utils.rs @@ -23,6 +23,7 @@ use datafusion_expr::{ expr_fn::{and, bitwise_and, bitwise_or, or}, Expr, Like, Operator, }; +use std::ops::Not; pub static POWS_OF_TEN: [i128; 38] = [ 1, @@ -69,7 +70,7 @@ pub static POWS_OF_TEN: [i128; 38] = [ /// expressions. Such as: (A AND B) AND C fn expr_contains_inner(expr: &Expr, needle: &Expr, search_op: Operator) -> bool { match expr { - Expr::BinaryExpr(BinaryExpr { left, op, right }) if *op == search_op => { + Expr::BinaryExpr(BinaryExpr { left, op, right }, _) if *op == search_op => { expr_contains_inner(left, needle, search_op) || expr_contains_inner(right, needle, search_op) } @@ -92,7 +93,7 @@ pub fn delete_xor_in_complex_expr(expr: &Expr, needle: &Expr, is_left: bool) -> xor_counter: &mut i32, ) -> Expr { match expr { - Expr::BinaryExpr(BinaryExpr { left, op, right }) + Expr::BinaryExpr(BinaryExpr { left, op, right }, _) if *op == Operator::BitwiseXor => { let left_expr = recursive_delete_xor_in_expr(left, needle, xor_counter); @@ -105,7 +106,7 @@ pub fn delete_xor_in_complex_expr(expr: &Expr, needle: &Expr, is_left: bool) -> return left_expr; } - Expr::BinaryExpr(BinaryExpr::new( + Expr::binary_expr(BinaryExpr::new( Box::new(left_expr), *op, Box::new(right_expr), @@ -121,13 +122,13 @@ pub fn delete_xor_in_complex_expr(expr: &Expr, needle: &Expr, is_left: bool) -> return needle.clone(); } else if xor_counter % 2 == 0 { if is_left { - return Expr::BinaryExpr(BinaryExpr::new( + return Expr::binary_expr(BinaryExpr::new( Box::new(needle.clone()), Operator::BitwiseXor, Box::new(result_expr), )); } else { - return Expr::BinaryExpr(BinaryExpr::new( + return Expr::binary_expr(BinaryExpr::new( Box::new(result_expr), Operator::BitwiseXor, Box::new(needle.clone()), @@ -139,34 +140,34 @@ pub fn delete_xor_in_complex_expr(expr: &Expr, needle: &Expr, is_left: bool) -> pub fn is_zero(s: &Expr) -> bool { match s { - Expr::Literal(ScalarValue::Int8(Some(0))) - | Expr::Literal(ScalarValue::Int16(Some(0))) - | Expr::Literal(ScalarValue::Int32(Some(0))) - | Expr::Literal(ScalarValue::Int64(Some(0))) - | Expr::Literal(ScalarValue::UInt8(Some(0))) - | Expr::Literal(ScalarValue::UInt16(Some(0))) - | Expr::Literal(ScalarValue::UInt32(Some(0))) - | Expr::Literal(ScalarValue::UInt64(Some(0))) => true, - Expr::Literal(ScalarValue::Float32(Some(v))) if *v == 0. => true, - Expr::Literal(ScalarValue::Float64(Some(v))) if *v == 0. => true, - Expr::Literal(ScalarValue::Decimal128(Some(v), _p, _s)) if *v == 0 => true, + Expr::Literal(ScalarValue::Int8(Some(0)), _) + | Expr::Literal(ScalarValue::Int16(Some(0)), _) + | Expr::Literal(ScalarValue::Int32(Some(0)), _) + | Expr::Literal(ScalarValue::Int64(Some(0)), _) + | Expr::Literal(ScalarValue::UInt8(Some(0)), _) + | Expr::Literal(ScalarValue::UInt16(Some(0)), _) + | Expr::Literal(ScalarValue::UInt32(Some(0)), _) + | Expr::Literal(ScalarValue::UInt64(Some(0)), _) => true, + Expr::Literal(ScalarValue::Float32(Some(v)), _) if *v == 0. => true, + Expr::Literal(ScalarValue::Float64(Some(v)), _) if *v == 0. => true, + Expr::Literal(ScalarValue::Decimal128(Some(v), _p, _s), _) if *v == 0 => true, _ => false, } } pub fn is_one(s: &Expr) -> bool { match s { - Expr::Literal(ScalarValue::Int8(Some(1))) - | Expr::Literal(ScalarValue::Int16(Some(1))) - | Expr::Literal(ScalarValue::Int32(Some(1))) - | Expr::Literal(ScalarValue::Int64(Some(1))) - | Expr::Literal(ScalarValue::UInt8(Some(1))) - | Expr::Literal(ScalarValue::UInt16(Some(1))) - | Expr::Literal(ScalarValue::UInt32(Some(1))) - | Expr::Literal(ScalarValue::UInt64(Some(1))) => true, - Expr::Literal(ScalarValue::Float32(Some(v))) if *v == 1. => true, - Expr::Literal(ScalarValue::Float64(Some(v))) if *v == 1. => true, - Expr::Literal(ScalarValue::Decimal128(Some(v), _p, s)) => { + Expr::Literal(ScalarValue::Int8(Some(1)), _) + | Expr::Literal(ScalarValue::Int16(Some(1)), _) + | Expr::Literal(ScalarValue::Int32(Some(1)), _) + | Expr::Literal(ScalarValue::Int64(Some(1)), _) + | Expr::Literal(ScalarValue::UInt8(Some(1)), _) + | Expr::Literal(ScalarValue::UInt16(Some(1)), _) + | Expr::Literal(ScalarValue::UInt32(Some(1)), _) + | Expr::Literal(ScalarValue::UInt64(Some(1)), _) => true, + Expr::Literal(ScalarValue::Float32(Some(v)), _) if *v == 1. => true, + Expr::Literal(ScalarValue::Float64(Some(v)), _) if *v == 1. => true, + Expr::Literal(ScalarValue::Decimal128(Some(v), _p, s), _) => { *s >= 0 && POWS_OF_TEN .get(*s as usize) @@ -179,7 +180,7 @@ pub fn is_one(s: &Expr) -> bool { pub fn is_true(expr: &Expr) -> bool { match expr { - Expr::Literal(ScalarValue::Boolean(Some(v))) => *v, + Expr::Literal(ScalarValue::Boolean(Some(v)), _) => *v, _ => false, } } @@ -187,48 +188,48 @@ pub fn is_true(expr: &Expr) -> bool { /// returns true if expr is a /// `Expr::Literal(ScalarValue::Boolean(v))` , false otherwise pub fn is_bool_lit(expr: &Expr) -> bool { - matches!(expr, Expr::Literal(ScalarValue::Boolean(_))) + matches!(expr, Expr::Literal(ScalarValue::Boolean(_), _)) } /// Return a literal NULL value of Boolean data type pub fn lit_bool_null() -> Expr { - Expr::Literal(ScalarValue::Boolean(None)) + Expr::literal(ScalarValue::Boolean(None)) } pub fn is_null(expr: &Expr) -> bool { match expr { - Expr::Literal(v) => v.is_null(), + Expr::Literal(v, _) => v.is_null(), _ => false, } } pub fn is_false(expr: &Expr) -> bool { match expr { - Expr::Literal(ScalarValue::Boolean(Some(v))) => !(*v), + Expr::Literal(ScalarValue::Boolean(Some(v)), _) => !(*v), _ => false, } } /// returns true if `haystack` looks like (needle OP X) or (X OP needle) pub fn is_op_with(target_op: Operator, haystack: &Expr, needle: &Expr) -> bool { - matches!(haystack, Expr::BinaryExpr(BinaryExpr { left, op, right }) if op == &target_op && (needle == left.as_ref() || needle == right.as_ref()) && !needle.is_volatile()) + matches!(haystack, Expr::BinaryExpr(BinaryExpr { left, op, right }, _) if op == &target_op && (needle == left.as_ref() || needle == right.as_ref()) && !needle.is_volatile()) } /// returns true if `not_expr` is !`expr` (not) pub fn is_not_of(not_expr: &Expr, expr: &Expr) -> bool { - matches!(not_expr, Expr::Not(inner) if expr == inner.as_ref()) + matches!(not_expr, Expr::Not(inner, _) if expr == inner.as_ref()) } /// returns true if `not_expr` is !`expr` (bitwise not) pub fn is_negative_of(not_expr: &Expr, expr: &Expr) -> bool { - matches!(not_expr, Expr::Negative(inner) if expr == inner.as_ref()) + matches!(not_expr, Expr::Negative(inner, _) if expr == inner.as_ref()) } /// returns the contained boolean value in `expr` as /// `Expr::Literal(ScalarValue::Boolean(v))`. pub fn as_bool_lit(expr: &Expr) -> Result> { match expr { - Expr::Literal(ScalarValue::Boolean(v)) => Ok(*v), + Expr::Literal(ScalarValue::Boolean(v), _) => Ok(*v), _ => internal_err!("Expected boolean literal, got {expr:?}"), } } @@ -249,9 +250,9 @@ pub fn as_bool_lit(expr: &Expr) -> Result> { /// For others, use Not clause pub fn negate_clause(expr: Expr) -> Expr { match expr { - Expr::BinaryExpr(BinaryExpr { left, op, right }) => { + Expr::BinaryExpr(BinaryExpr { left, op, right }, _) => { if let Some(negated_op) = op.negate() { - return Expr::BinaryExpr(BinaryExpr::new(left, negated_op, right)); + return Expr::binary_expr(BinaryExpr::new(left, negated_op, right)); } match op { // not (A and B) ===> (not A) or (not B) @@ -269,34 +270,35 @@ pub fn negate_clause(expr: Expr) -> Expr { and(left, right) } // use not clause - _ => Expr::Not(Box::new(Expr::BinaryExpr(BinaryExpr::new( - left, op, right, - )))), + _ => Expr::binary_expr(BinaryExpr::new(left, op, right)).not(), } } // not (not A) ===> A - Expr::Not(expr) => *expr, + Expr::Not(expr, _) => *expr, // not (A is not null) ===> A is null - Expr::IsNotNull(expr) => expr.is_null(), + Expr::IsNotNull(expr, _) => expr.is_null(), // not (A is null) ===> A is not null - Expr::IsNull(expr) => expr.is_not_null(), + Expr::IsNull(expr, _) => expr.is_not_null(), // not (A not in (..)) ===> A in (..) // not (A in (..)) ===> A not in (..) - Expr::InList(InList { - expr, - list, - negated, - }) => expr.in_list(list, !negated), + Expr::InList( + InList { + expr, + list, + negated, + }, + _, + ) => expr.in_list(list, !negated), // not (A between B and C) ===> (A not between B and C) // not (A not between B and C) ===> (A between B and C) - Expr::Between(between) => Expr::Between(Between::new( + Expr::Between(between, _) => Expr::_between(Between::new( between.expr, !between.negated, between.low, between.high, )), // not (A like B) ===> A not like B - Expr::Like(like) => Expr::Like(Like::new( + Expr::Like(like, _) => Expr::_like(Like::new( !like.negated, like.expr, like.pattern, @@ -304,7 +306,7 @@ pub fn negate_clause(expr: Expr) -> Expr { like.case_insensitive, )), // use not clause - _ => Expr::Not(Box::new(expr)), + _ => expr.not(), } } @@ -318,7 +320,7 @@ pub fn negate_clause(expr: Expr) -> Expr { /// For others, use Negative clause pub fn distribute_negation(expr: Expr) -> Expr { match expr { - Expr::BinaryExpr(BinaryExpr { left, op, right }) => { + Expr::BinaryExpr(BinaryExpr { left, op, right }, _) => { match op { // ~(A & B) ===> ~A | ~B Operator::BitwiseAnd => { @@ -335,14 +337,14 @@ pub fn distribute_negation(expr: Expr) -> Expr { bitwise_and(left, right) } // use negative clause - _ => Expr::Negative(Box::new(Expr::BinaryExpr(BinaryExpr::new( + _ => Expr::negative(Box::new(Expr::binary_expr(BinaryExpr::new( left, op, right, )))), } } // ~(~A) ===> A - Expr::Negative(expr) => *expr, + Expr::Negative(expr, _) => *expr, // use negative clause - _ => Expr::Negative(Box::new(expr)), + _ => Expr::negative(Box::new(expr)), } } diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs index c8f3a4bc7859c..1cefe352fa939 100644 --- a/datafusion/optimizer/src/single_distinct_to_groupby.rs +++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs @@ -66,14 +66,17 @@ fn is_single_distinct_agg(aggr_expr: &[Expr]) -> Result { let mut fields_set = HashSet::new(); let mut aggregate_count = 0; for expr in aggr_expr { - if let Expr::AggregateFunction(AggregateFunction { - func, - distinct, - args, - filter, - order_by, - null_treatment: _, - }) = expr + if let Expr::AggregateFunction( + AggregateFunction { + func, + distinct, + args, + filter, + order_by, + null_treatment: _, + }, + _, + ) = expr { if filter.is_some() || order_by.is_some() { return Ok(false); @@ -98,7 +101,7 @@ fn is_single_distinct_agg(aggr_expr: &[Expr]) -> Result { /// Check if the first expr is [Expr::GroupingSet]. fn contains_grouping_set(expr: &[Expr]) -> bool { - matches!(expr.first(), Some(Expr::GroupingSet(_))) + matches!(expr.first(), Some(Expr::GroupingSet(_, _))) } impl OptimizerRule for SingleDistinctToGroupBy { @@ -120,13 +123,16 @@ impl OptimizerRule for SingleDistinctToGroupBy { _config: &dyn OptimizerConfig, ) -> Result, DataFusionError> { match plan { - LogicalPlan::Aggregate(Aggregate { - input, - aggr_expr, - schema, - group_expr, - .. - }) if is_single_distinct_agg(&aggr_expr)? + LogicalPlan::Aggregate( + Aggregate { + input, + aggr_expr, + schema, + group_expr, + .. + }, + _, + ) if is_single_distinct_agg(&aggr_expr)? && !contains_grouping_set(&group_expr) => { let group_size = group_expr.len(); @@ -138,7 +144,7 @@ impl OptimizerRule for SingleDistinctToGroupBy { .into_iter() .enumerate() .map(|(i, group_expr)| { - if let Expr::Column(_) = group_expr { + if let Expr::Column(_, _) = group_expr { // For Column expressions we can use existing expression as is. (group_expr.clone(), (group_expr, None)) } else { @@ -182,7 +188,7 @@ impl OptimizerRule for SingleDistinctToGroupBy { mut args, distinct, .. - }) => { + }, _) => { if distinct { if args.len() != 1 { return internal_err!("DISTINCT aggregate should have exactly one argument"); @@ -193,7 +199,7 @@ impl OptimizerRule for SingleDistinctToGroupBy { inner_group_exprs .push(arg.alias(SINGLE_DISTINCT_ALIAS)); } - Ok(Expr::AggregateFunction(AggregateFunction::new_udf( + Ok(Expr::aggregate_function(AggregateFunction::new_udf( func, vec![col(SINGLE_DISTINCT_ALIAS)], false, // intentional to remove distinct here @@ -206,7 +212,7 @@ impl OptimizerRule for SingleDistinctToGroupBy { index += 1; let alias_str = format!("alias{}", index); inner_aggr_exprs.push( - Expr::AggregateFunction(AggregateFunction::new_udf( + Expr::aggregate_function(AggregateFunction::new_udf( Arc::clone(&func), args, false, @@ -216,7 +222,7 @@ impl OptimizerRule for SingleDistinctToGroupBy { )) .alias(&alias_str), ); - Ok(Expr::AggregateFunction(AggregateFunction::new_udf( + Ok(Expr::aggregate_function(AggregateFunction::new_udf( func, vec![col(&alias_str)], false, @@ -231,7 +237,7 @@ impl OptimizerRule for SingleDistinctToGroupBy { .collect::>>()?; // construct the inner AggrPlan - let inner_agg = LogicalPlan::Aggregate(Aggregate::try_new( + let inner_agg = LogicalPlan::aggregate(Aggregate::try_new( input, inner_group_exprs, inner_aggr_exprs, @@ -263,7 +269,7 @@ impl OptimizerRule for SingleDistinctToGroupBy { )) .collect(); - let outer_aggr = LogicalPlan::Aggregate(Aggregate::try_new( + let outer_aggr = LogicalPlan::aggregate(Aggregate::try_new( Arc::new(inner_agg), outer_group_exprs, outer_aggr_exprs, @@ -288,7 +294,7 @@ mod tests { use datafusion_functions_aggregate::sum::sum_udaf; fn max_distinct(expr: Expr) -> Expr { - Expr::AggregateFunction(AggregateFunction::new_udf( + Expr::aggregate_function(AggregateFunction::new_udf( max_udaf(), vec![expr], true, @@ -345,7 +351,7 @@ mod tests { fn single_distinct_and_grouping_set() -> Result<()> { let table_scan = test_table_scan()?; - let grouping_set = Expr::GroupingSet(GroupingSet::GroupingSets(vec![ + let grouping_set = Expr::grouping_set(GroupingSet::GroupingSets(vec![ vec![col("a")], vec![col("b")], ])); @@ -366,7 +372,8 @@ mod tests { fn single_distinct_and_cube() -> Result<()> { let table_scan = test_table_scan()?; - let grouping_set = Expr::GroupingSet(GroupingSet::Cube(vec![col("a"), col("b")])); + let grouping_set = + Expr::grouping_set(GroupingSet::Cube(vec![col("a"), col("b")])); let plan = LogicalPlanBuilder::from(table_scan) .aggregate(vec![grouping_set], vec![count_distinct(col("c"))])? @@ -385,7 +392,7 @@ mod tests { let table_scan = test_table_scan()?; let grouping_set = - Expr::GroupingSet(GroupingSet::Rollup(vec![col("a"), col("b")])); + Expr::grouping_set(GroupingSet::Rollup(vec![col("a"), col("b")])); let plan = LogicalPlanBuilder::from(table_scan) .aggregate(vec![grouping_set], vec![count_distinct(col("c"))])? @@ -569,7 +576,7 @@ mod tests { let table_scan = test_table_scan()?; // sum(a) FILTER (WHERE a > 5) - let expr = Expr::AggregateFunction(AggregateFunction::new_udf( + let expr = Expr::aggregate_function(AggregateFunction::new_udf( sum_udaf(), vec![col("a")], false, @@ -612,7 +619,7 @@ mod tests { let table_scan = test_table_scan()?; // SUM(a ORDER BY a) - let expr = Expr::AggregateFunction(AggregateFunction::new_udf( + let expr = Expr::aggregate_function(AggregateFunction::new_udf( sum_udaf(), vec![col("a")], false, diff --git a/datafusion/optimizer/src/test/user_defined.rs b/datafusion/optimizer/src/test/user_defined.rs index a39f90b5da5db..94875ae17fb92 100644 --- a/datafusion/optimizer/src/test/user_defined.rs +++ b/datafusion/optimizer/src/test/user_defined.rs @@ -30,7 +30,7 @@ use std::{ /// Create a new user defined plan node, for testing pub fn new(input: LogicalPlan) -> LogicalPlan { let node = Arc::new(TestUserDefinedPlanNode { input }); - LogicalPlan::Extension(Extension { node }) + LogicalPlan::extension(Extension { node }) } #[derive(PartialEq, Eq, PartialOrd, Hash)] diff --git a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs index 31e21d08b569a..14bd6ee2359bd 100644 --- a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs +++ b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs @@ -17,6 +17,7 @@ //! [`UnwrapCastInComparison`] rewrites `CAST(col) = lit` to `col = CAST(lit)` +use std::cell::Cell; use std::cmp::Ordering; use std::mem; use std::sync::Arc; @@ -32,8 +33,10 @@ use arrow::temporal_conversions::{MICROSECONDS, MILLISECONDS, NANOSECONDS}; use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter}; use datafusion_common::{internal_err, DFSchema, DFSchemaRef, Result, ScalarValue}; use datafusion_expr::expr::{BinaryExpr, Cast, InList, TryCast}; +use datafusion_expr::logical_plan::tree_node::LogicalPlanPattern; use datafusion_expr::utils::merge_schema; use datafusion_expr::{lit, Expr, ExprSchemable, LogicalPlan}; +use enumset::enum_set; /// [`UnwrapCastInComparison`] attempts to remove casts from /// comparisons to literals ([`ScalarValue`]s) by applying the casts @@ -87,7 +90,7 @@ impl OptimizerRule for UnwrapCastInComparison { } fn apply_order(&self) -> Option { - Some(ApplyOrder::BottomUp) + None } fn supports_rewrite(&self) -> bool { @@ -99,28 +102,54 @@ impl OptimizerRule for UnwrapCastInComparison { plan: LogicalPlan, _config: &dyn OptimizerConfig, ) -> Result> { - let mut schema = merge_schema(&plan.inputs()); - - if let LogicalPlan::TableScan(ts) = &plan { - let source_schema = DFSchema::try_from_qualified_schema( - ts.table_name.clone(), - &ts.source.schema(), - )?; - schema.merge(&source_schema); - } + let skip = Cell::new(false); + plan.transform_down_up_with_subqueries( + |plan| { + if !(plan.stats().contains_any_patterns(enum_set!( + LogicalPlanPattern::ExprBinaryExpr | LogicalPlanPattern::ExprInList + )) && plan.stats().contains_any_patterns(enum_set!( + LogicalPlanPattern::ExprTryCast | LogicalPlanPattern::ExprCast + )) && plan + .stats() + .contains_pattern(LogicalPlanPattern::ExprLiteral)) + { + skip.set(true); + return Ok(Transformed::jump(plan)); + } - schema.merge(plan.schema()); + Ok(Transformed::no(plan)) + }, + |plan| { + if skip.get() { + skip.set(false); + return Ok(Transformed::no(plan)); + } - let mut expr_rewriter = UnwrapCastExprRewriter { - schema: Arc::new(schema), - }; + let mut schema = merge_schema(&plan.inputs()); + + if let LogicalPlan::TableScan(ts, _) = &plan { + let source_schema = DFSchema::try_from_qualified_schema( + ts.table_name.clone(), + &ts.source.schema(), + )?; + schema.merge(&source_schema); + } + + schema.merge(plan.schema()); + + let mut expr_rewriter = UnwrapCastExprRewriter { + schema: Arc::new(schema), + }; - let name_preserver = NamePreserver::new(&plan); - plan.map_expressions(|expr| { - let original_name = name_preserver.save(&expr); - expr.rewrite(&mut expr_rewriter) - .map(|transformed| transformed.update_data(|e| original_name.restore(e))) - }) + let name_preserver = NamePreserver::new(&plan); + plan.map_expressions(|expr| { + let original_name = name_preserver.save(&expr); + expr.rewrite(&mut expr_rewriter).map(|transformed| { + transformed.update_data(|e| original_name.restore(e)) + }) + }) + }, + ) } } @@ -131,12 +160,27 @@ struct UnwrapCastExprRewriter { impl TreeNodeRewriter for UnwrapCastExprRewriter { type Node = Expr; + fn f_down(&mut self, node: Self::Node) -> Result> { + if !(node.stats().contains_any_patterns(enum_set!( + LogicalPlanPattern::ExprBinaryExpr | LogicalPlanPattern::ExprInList + )) && node.stats().contains_any_patterns(enum_set!( + LogicalPlanPattern::ExprTryCast | LogicalPlanPattern::ExprCast + )) && node + .stats() + .contains_pattern(LogicalPlanPattern::ExprLiteral)) + { + return Ok(Transformed::jump(node)); + } + + Ok(Transformed::no(node)) + } + fn f_up(&mut self, mut expr: Expr) -> Result> { match &mut expr { // For case: // try_cast/cast(expr as data_type) op literal // literal op try_cast/cast(expr as data_type) - Expr::BinaryExpr(BinaryExpr { left, op, right }) + Expr::BinaryExpr(BinaryExpr { left, op, right }, _) if { let Ok(left_type) = left.get_type(&self.schema) else { return Ok(Transformed::no(expr)); @@ -151,13 +195,19 @@ impl TreeNodeRewriter for UnwrapCastExprRewriter { { match (left.as_mut(), right.as_mut()) { ( - Expr::Literal(left_lit_value), - Expr::TryCast(TryCast { - expr: right_expr, .. - }) - | Expr::Cast(Cast { - expr: right_expr, .. - }), + Expr::Literal(left_lit_value, _), + Expr::TryCast( + TryCast { + expr: right_expr, .. + }, + _, + ) + | Expr::Cast( + Cast { + expr: right_expr, .. + }, + _, + ), ) => { // if the left_lit_value can be cast to the type of expr // we need to unwrap the cast for cast/try_cast expr, and add cast to the literal @@ -173,21 +223,30 @@ impl TreeNodeRewriter for UnwrapCastExprRewriter { else { return Ok(Transformed::no(expr)); }; - **left = lit(value); // unwrap the cast/try_cast for the right expr - **right = mem::take(right_expr); + expr = Expr::binary_expr(BinaryExpr { + left: Box::new(lit(value)), + op: *op, + right: Box::new(mem::take(right_expr)), + }); Ok(Transformed::yes(expr)) } } } ( - Expr::TryCast(TryCast { - expr: left_expr, .. - }) - | Expr::Cast(Cast { - expr: left_expr, .. - }), - Expr::Literal(right_lit_value), + Expr::TryCast( + TryCast { + expr: left_expr, .. + }, + _, + ) + | Expr::Cast( + Cast { + expr: left_expr, .. + }, + _, + ), + Expr::Literal(right_lit_value, _), ) => { // if the right_lit_value can be cast to the type of expr // we need to unwrap the cast for cast/try_cast expr, and add cast to the literal @@ -204,8 +263,11 @@ impl TreeNodeRewriter for UnwrapCastExprRewriter { return Ok(Transformed::no(expr)); }; // unwrap the cast/try_cast for the left expr - **left = mem::take(left_expr); - **right = lit(value); + expr = Expr::binary_expr(BinaryExpr { + left: mem::take(left_expr), + op: *op, + right: Box::new(lit(value)), + }); Ok(Transformed::yes(expr)) } } @@ -215,15 +277,26 @@ impl TreeNodeRewriter for UnwrapCastExprRewriter { } // For case: // try_cast/cast(expr as left_type) in (expr1,expr2,expr3) - Expr::InList(InList { - expr: left, list, .. - }) => { - let (Expr::TryCast(TryCast { - expr: left_expr, .. - }) - | Expr::Cast(Cast { - expr: left_expr, .. - })) = left.as_mut() + Expr::InList( + InList { + expr: left, + list, + negated, + }, + _, + ) => { + let (Expr::TryCast( + TryCast { + expr: left_expr, .. + }, + _, + ) + | Expr::Cast( + Cast { + expr: left_expr, .. + }, + _, + )) = left.as_mut() else { return Ok(Transformed::no(expr)); }; @@ -244,7 +317,7 @@ impl TreeNodeRewriter for UnwrapCastExprRewriter { )?; } match right { - Expr::Literal(right_lit_value) => { + Expr::Literal(right_lit_value, _) => { // if the right_lit_value can be casted to the type of internal_left_expr // we need to unwrap the cast for cast/try_cast expr, and add cast to the literal let Some(value) = try_cast_literal_to_type(right_lit_value, &expr_type) else { @@ -264,8 +337,11 @@ impl TreeNodeRewriter for UnwrapCastExprRewriter { .collect::>>() else { return Ok(Transformed::no(expr)) }; - **left = mem::take(left_expr); - *list = right_exprs; + expr = Expr::_in_list(InList { + expr: Box::new(mem::take(left_expr)), + list: right_exprs, + negated: *negated, + }); Ok(Transformed::yes(expr)) } // TODO: handle other expr type and dfs visit them diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index 9f325bc01b1d0..8cb85dc3300ce 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -132,7 +132,7 @@ pub fn is_restrict_null_predicate<'a>( predicate: Expr, join_cols_of_predicate: impl IntoIterator, ) -> Result { - if matches!(predicate, Expr::Column(_)) { + if matches!(predicate, Expr::Column(_, _)) { return Ok(true); } @@ -195,10 +195,10 @@ mod tests { // a IS NULL (is_null(col("a")), false), // a IS NOT NULL - (Expr::IsNotNull(Box::new(col("a"))), true), + (Expr::_is_not_null(Box::new(col("a"))), true), // a = NULL ( - binary_expr(col("a"), Operator::Eq, Expr::Literal(ScalarValue::Null)), + binary_expr(col("a"), Operator::Eq, Expr::literal(ScalarValue::Null)), true, ), // a > 8 @@ -261,12 +261,12 @@ mod tests { ), // a IN (NULL) ( - in_list(col("a"), vec![Expr::Literal(ScalarValue::Null)], false), + in_list(col("a"), vec![Expr::literal(ScalarValue::Null)], false), true, ), // a NOT IN (NULL) ( - in_list(col("a"), vec![Expr::Literal(ScalarValue::Null)], true), + in_list(col("a"), vec![Expr::literal(ScalarValue::Null)], true), true, ), ]; diff --git a/datafusion/physical-expr/src/expressions/literal.rs b/datafusion/physical-expr/src/expressions/literal.rs index f0d02eb605b26..b8f2b7aa2a81c 100644 --- a/datafusion/physical-expr/src/expressions/literal.rs +++ b/datafusion/physical-expr/src/expressions/literal.rs @@ -97,7 +97,7 @@ impl PhysicalExpr for Literal { /// Create a literal expression pub fn lit(value: T) -> Arc { match value.lit() { - Expr::Literal(v) => Arc::new(Literal::new(v)), + Expr::Literal(v, _) => Arc::new(Literal::new(v)), _ => unreachable!(), } } diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs index add6c18b329cf..b8365f011838f 100644 --- a/datafusion/physical-expr/src/planner.rs +++ b/datafusion/physical-expr/src/planner.rs @@ -111,15 +111,15 @@ pub fn create_physical_expr( let input_schema: &Schema = &input_dfschema.into(); match e { - Expr::Alias(Alias { expr, .. }) => { + Expr::Alias(Alias { expr, .. }, _) => { Ok(create_physical_expr(expr, input_dfschema, execution_props)?) } - Expr::Column(c) => { + Expr::Column(c, _) => { let idx = input_dfschema.index_of_column(c)?; Ok(Arc::new(Column::new(&c.name, idx))) } - Expr::Literal(value) => Ok(Arc::new(Literal::new(value.clone()))), - Expr::ScalarVariable(_, variable_names) => { + Expr::Literal(value, _) => Ok(Arc::new(Literal::new(value.clone()))), + Expr::ScalarVariable(_, variable_names, _) => { if is_system_variables(variable_names) { match execution_props.get_var_provider(VarType::System) { Some(provider) => { @@ -138,7 +138,7 @@ pub fn create_physical_expr( } } } - Expr::IsTrue(expr) => { + Expr::IsTrue(expr, _) => { let binary_op = binary_expr( expr.as_ref().clone(), Operator::IsNotDistinctFrom, @@ -146,12 +146,12 @@ pub fn create_physical_expr( ); create_physical_expr(&binary_op, input_dfschema, execution_props) } - Expr::IsNotTrue(expr) => { + Expr::IsNotTrue(expr, _) => { let binary_op = binary_expr(expr.as_ref().clone(), Operator::IsDistinctFrom, lit(true)); create_physical_expr(&binary_op, input_dfschema, execution_props) } - Expr::IsFalse(expr) => { + Expr::IsFalse(expr, _) => { let binary_op = binary_expr( expr.as_ref().clone(), Operator::IsNotDistinctFrom, @@ -159,28 +159,28 @@ pub fn create_physical_expr( ); create_physical_expr(&binary_op, input_dfschema, execution_props) } - Expr::IsNotFalse(expr) => { + Expr::IsNotFalse(expr, _) => { let binary_op = binary_expr(expr.as_ref().clone(), Operator::IsDistinctFrom, lit(false)); create_physical_expr(&binary_op, input_dfschema, execution_props) } - Expr::IsUnknown(expr) => { + Expr::IsUnknown(expr, _) => { let binary_op = binary_expr( expr.as_ref().clone(), Operator::IsNotDistinctFrom, - Expr::Literal(ScalarValue::Boolean(None)), + Expr::literal(ScalarValue::Boolean(None)), ); create_physical_expr(&binary_op, input_dfschema, execution_props) } - Expr::IsNotUnknown(expr) => { + Expr::IsNotUnknown(expr, _) => { let binary_op = binary_expr( expr.as_ref().clone(), Operator::IsDistinctFrom, - Expr::Literal(ScalarValue::Boolean(None)), + Expr::literal(ScalarValue::Boolean(None)), ); create_physical_expr(&binary_op, input_dfschema, execution_props) } - Expr::BinaryExpr(BinaryExpr { left, op, right }) => { + Expr::BinaryExpr(BinaryExpr { left, op, right }, _) => { // Create physical expressions for left and right operands let lhs = create_physical_expr(left, input_dfschema, execution_props)?; let rhs = create_physical_expr(right, input_dfschema, execution_props)?; @@ -193,13 +193,16 @@ pub fn create_physical_expr( // planning. binary(lhs, *op, rhs, input_schema) } - Expr::Like(Like { - negated, - expr, - pattern, - escape_char, - case_insensitive, - }) => { + Expr::Like( + Like { + negated, + expr, + pattern, + escape_char, + case_insensitive, + }, + _, + ) => { // `\` is the implicit escape, see https://github.com/apache/datafusion/issues/13291 if escape_char.unwrap_or('\\') != '\\' { return exec_err!( @@ -218,13 +221,16 @@ pub fn create_physical_expr( input_schema, ) } - Expr::SimilarTo(Like { - negated, - expr, - pattern, - escape_char, - case_insensitive, - }) => { + Expr::SimilarTo( + Like { + negated, + expr, + pattern, + escape_char, + case_insensitive, + }, + _, + ) => { if escape_char.is_some() { return exec_err!("SIMILAR TO does not support escape_char yet"); } @@ -234,7 +240,7 @@ pub fn create_physical_expr( create_physical_expr(pattern, input_dfschema, execution_props)?; similar_to(*negated, *case_insensitive, physical_expr, physical_pattern) } - Expr::Case(case) => { + Expr::Case(case, _) => { let expr: Option> = if let Some(e) = &case.expr { Some(create_physical_expr( e.as_ref(), @@ -271,34 +277,34 @@ pub fn create_physical_expr( }; Ok(expressions::case(expr, when_then_expr, else_expr)?) } - Expr::Cast(Cast { expr, data_type }) => expressions::cast( + Expr::Cast(Cast { expr, data_type }, _) => expressions::cast( create_physical_expr(expr, input_dfschema, execution_props)?, input_schema, data_type.clone(), ), - Expr::TryCast(TryCast { expr, data_type }) => expressions::try_cast( + Expr::TryCast(TryCast { expr, data_type }, _) => expressions::try_cast( create_physical_expr(expr, input_dfschema, execution_props)?, input_schema, data_type.clone(), ), - Expr::Not(expr) => { + Expr::Not(expr, _) => { expressions::not(create_physical_expr(expr, input_dfschema, execution_props)?) } - Expr::Negative(expr) => expressions::negative( + Expr::Negative(expr, _) => expressions::negative( create_physical_expr(expr, input_dfschema, execution_props)?, input_schema, ), - Expr::IsNull(expr) => expressions::is_null(create_physical_expr( + Expr::IsNull(expr, _) => expressions::is_null(create_physical_expr( expr, input_dfschema, execution_props, )?), - Expr::IsNotNull(expr) => expressions::is_not_null(create_physical_expr( + Expr::IsNotNull(expr, _) => expressions::is_not_null(create_physical_expr( expr, input_dfschema, execution_props, )?), - Expr::ScalarFunction(ScalarFunction { func, args }) => { + Expr::ScalarFunction(ScalarFunction { func, args }, _) => { let physical_args = create_physical_exprs(args, input_dfschema, execution_props)?; @@ -310,12 +316,15 @@ pub fn create_physical_expr( input_dfschema, ) } - Expr::Between(Between { - expr, - negated, - low, - high, - }) => { + Expr::Between( + Between { + expr, + negated, + low, + high, + }, + _, + ) => { let value_expr = create_physical_expr(expr, input_dfschema, execution_props)?; let low_expr = create_physical_expr(low, input_dfschema, execution_props)?; let high_expr = create_physical_expr(high, input_dfschema, execution_props)?; @@ -344,12 +353,15 @@ pub fn create_physical_expr( binary_expr } } - Expr::InList(InList { - expr, - list, - negated, - }) => match expr.as_ref() { - Expr::Literal(ScalarValue::Utf8(None)) => { + Expr::InList( + InList { + expr, + list, + negated, + }, + _, + ) => match expr.as_ref() { + Expr::Literal(ScalarValue::Utf8(None), _) => { Ok(expressions::lit(ScalarValue::Boolean(None))) } _ => { diff --git a/datafusion/physical-optimizer/src/pruning.rs b/datafusion/physical-optimizer/src/pruning.rs index 3cfb03b7205a5..52d7edff42ec0 100644 --- a/datafusion/physical-optimizer/src/pruning.rs +++ b/datafusion/physical-optimizer/src/pruning.rs @@ -2549,7 +2549,7 @@ mod tests { Field::new("c2", DataType::Int32, false), ]); // test c1 in(1, 2, 3) - let expr = Expr::InList(InList::new( + let expr = Expr::_in_list(InList::new( Box::new(col("c1")), vec![lit(1), lit(2), lit(3)], false, @@ -2580,7 +2580,7 @@ mod tests { Field::new("c2", DataType::Int32, false), ]); // test c1 in() - let expr = Expr::InList(InList::new(Box::new(col("c1")), vec![], false)); + let expr = Expr::_in_list(InList::new(Box::new(col("c1")), vec![], false)); let expected_expr = "true"; let predicate_expr = test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); @@ -2596,7 +2596,7 @@ mod tests { Field::new("c2", DataType::Int32, false), ]); // test c1 not in(1, 2, 3) - let expr = Expr::InList(InList::new( + let expr = Expr::_in_list(InList::new( Box::new(col("c1")), vec![lit(1), lit(2), lit(3)], true, @@ -2747,7 +2747,7 @@ mod tests { fn row_group_predicate_cast_list() -> Result<()> { let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]); // test cast(c1 as int64) in int64(1, 2, 3) - let expr = Expr::InList(InList::new( + let expr = Expr::_in_list(InList::new( Box::new(cast(col("c1"), DataType::Int64)), vec![ lit(ScalarValue::Int64(Some(1))), @@ -2772,7 +2772,7 @@ mod tests { test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); - let expr = Expr::InList(InList::new( + let expr = Expr::_in_list(InList::new( Box::new(cast(col("c1"), DataType::Int64)), vec![ lit(ScalarValue::Int64(Some(1))), @@ -3107,7 +3107,7 @@ mod tests { // -i < 0 prune_with_expr( - Expr::Negative(Box::new(col("i"))).lt(lit(0)), + Expr::negative(Box::new(col("i"))).lt(lit(0)), &schema, &statistics, expected_ret, @@ -3136,7 +3136,7 @@ mod tests { prune_with_expr( // -i >= 0 - Expr::Negative(Box::new(col("i"))).gt_eq(lit(0)), + Expr::negative(Box::new(col("i"))).gt_eq(lit(0)), &schema, &statistics, expected_ret, @@ -3173,7 +3173,7 @@ mod tests { prune_with_expr( // cast(-i as utf8) >= 0 - cast(Expr::Negative(Box::new(col("i"))), DataType::Utf8).gt_eq(lit("0")), + cast(Expr::negative(Box::new(col("i"))), DataType::Utf8).gt_eq(lit("0")), &schema, &statistics, expected_ret, @@ -3181,7 +3181,7 @@ mod tests { prune_with_expr( // try_cast(-i as utf8) >= 0 - try_cast(Expr::Negative(Box::new(col("i"))), DataType::Utf8).gt_eq(lit("0")), + try_cast(Expr::negative(Box::new(col("i"))), DataType::Utf8).gt_eq(lit("0")), &schema, &statistics, expected_ret, @@ -3281,7 +3281,7 @@ mod tests { prune_with_expr( // -i < 1 - Expr::Negative(Box::new(col("i"))).lt(lit(1)), + Expr::negative(Box::new(col("i"))).lt(lit(1)), &schema, &statistics, expected_ret, @@ -3431,7 +3431,7 @@ mod tests { prune_with_expr( // `-cast(i as int64) < 0` convert to `cast(i as int64) > -0` - Expr::Negative(Box::new(cast(col("i"), DataType::Int64))) + Expr::negative(Box::new(cast(col("i"), DataType::Int64))) .lt(lit(ScalarValue::Int64(Some(0)))), &schema, &statistics, diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 64d8e24ce5182..1446694edb2f0 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -22,7 +22,7 @@ use datafusion_common::{ exec_datafusion_err, internal_err, plan_datafusion_err, RecursionUnnestOption, Result, ScalarValue, TableReference, UnnestOptions, }; -use datafusion_expr::expr::{Alias, Placeholder, Sort, Wildcard}; +use datafusion_expr::expr::{Placeholder, Sort, Wildcard}; use datafusion_expr::expr::{Unnest, WildcardOptions}; use datafusion_expr::ExprFunctionExt; use datafusion_expr::{ @@ -245,14 +245,18 @@ pub fn parse_expr( Ok(operands .into_iter() .reduce(|left, right| { - Expr::BinaryExpr(BinaryExpr::new(Box::new(left), op, Box::new(right))) + Expr::binary_expr(BinaryExpr::new( + Box::new(left), + op, + Box::new(right), + )) }) .expect("Binary expression could not be reduced to a single expression.")) } - ExprType::Column(column) => Ok(Expr::Column(column.into())), + ExprType::Column(column) => Ok(Expr::column(column.into())), ExprType::Literal(literal) => { let scalar_value: ScalarValue = literal.try_into()?; - Ok(Expr::Literal(scalar_value)) + Ok(Expr::literal(scalar_value)) } ExprType::WindowExpr(expr) => { let window_function = expr @@ -284,7 +288,7 @@ pub fn parse_expr( }; let args = parse_exprs(&expr.exprs, registry, codec)?; - Expr::WindowFunction(WindowFunction::new( + Expr::window_function(WindowFunction::new( expr::WindowFunctionDefinition::AggregateUDF(udaf_function), args, )) @@ -301,7 +305,7 @@ pub fn parse_expr( }; let args = parse_exprs(&expr.exprs, registry, codec)?; - Expr::WindowFunction(WindowFunction::new( + Expr::window_function(WindowFunction::new( expr::WindowFunctionDefinition::WindowUDF(udwf_function), args, )) @@ -313,64 +317,59 @@ pub fn parse_expr( } } } - ExprType::Alias(alias) => Ok(Expr::Alias(Alias::new( - parse_required_expr(alias.expr.as_deref(), registry, "expr", codec)?, - alias - .relation - .first() - .map(|r| TableReference::try_from(r.clone())) - .transpose()?, - alias.alias.clone(), + ExprType::Alias(alias) => { + Ok( + parse_required_expr(alias.expr.as_deref(), registry, "expr", codec)? + .alias_qualified( + alias + .relation + .first() + .map(|r| TableReference::try_from(r.clone())) + .transpose()?, + alias.alias.clone(), + ), + ) + } + ExprType::IsNullExpr(is_null) => Ok(Expr::_is_null(Box::new( + parse_required_expr(is_null.expr.as_deref(), registry, "expr", codec)?, ))), - ExprType::IsNullExpr(is_null) => Ok(Expr::IsNull(Box::new(parse_required_expr( - is_null.expr.as_deref(), - registry, - "expr", - codec, - )?))), - ExprType::IsNotNullExpr(is_not_null) => Ok(Expr::IsNotNull(Box::new( + ExprType::IsNotNullExpr(is_not_null) => Ok(Expr::_is_not_null(Box::new( parse_required_expr(is_not_null.expr.as_deref(), registry, "expr", codec)?, ))), - ExprType::NotExpr(not) => Ok(Expr::Not(Box::new(parse_required_expr( + ExprType::NotExpr(not) => Ok(Expr::_not(Box::new(parse_required_expr( not.expr.as_deref(), registry, "expr", codec, )?))), - ExprType::IsTrue(msg) => Ok(Expr::IsTrue(Box::new(parse_required_expr( - msg.expr.as_deref(), - registry, - "expr", - codec, - )?))), - ExprType::IsFalse(msg) => Ok(Expr::IsFalse(Box::new(parse_required_expr( - msg.expr.as_deref(), - registry, - "expr", - codec, - )?))), - ExprType::IsUnknown(msg) => Ok(Expr::IsUnknown(Box::new(parse_required_expr( + ExprType::IsTrue(msg) => Ok(Expr::_is_true(Box::new(parse_required_expr( msg.expr.as_deref(), registry, "expr", codec, )?))), - ExprType::IsNotTrue(msg) => Ok(Expr::IsNotTrue(Box::new(parse_required_expr( + ExprType::IsFalse(msg) => Ok(Expr::_is_false(Box::new(parse_required_expr( msg.expr.as_deref(), registry, "expr", codec, )?))), - ExprType::IsNotFalse(msg) => Ok(Expr::IsNotFalse(Box::new(parse_required_expr( + ExprType::IsUnknown(msg) => Ok(Expr::_is_unknown(Box::new(parse_required_expr( msg.expr.as_deref(), registry, "expr", codec, )?))), - ExprType::IsNotUnknown(msg) => Ok(Expr::IsNotUnknown(Box::new( + ExprType::IsNotTrue(msg) => Ok(Expr::_is_not_true(Box::new( + parse_required_expr(msg.expr.as_deref(), registry, "expr", codec)?, + ))), + ExprType::IsNotFalse(msg) => Ok(Expr::_is_not_false(Box::new( + parse_required_expr(msg.expr.as_deref(), registry, "expr", codec)?, + ))), + ExprType::IsNotUnknown(msg) => Ok(Expr::_is_not_unknown(Box::new( parse_required_expr(msg.expr.as_deref(), registry, "expr", codec)?, ))), - ExprType::Between(between) => Ok(Expr::Between(Between::new( + ExprType::Between(between) => Ok(Expr::_between(Between::new( Box::new(parse_required_expr( between.expr.as_deref(), registry, @@ -391,7 +390,7 @@ pub fn parse_expr( codec, )?), ))), - ExprType::Like(like) => Ok(Expr::Like(Like::new( + ExprType::Like(like) => Ok(Expr::_like(Like::new( like.negated, Box::new(parse_required_expr( like.expr.as_deref(), @@ -408,7 +407,7 @@ pub fn parse_expr( parse_escape_char(&like.escape_char)?, false, ))), - ExprType::Ilike(like) => Ok(Expr::Like(Like::new( + ExprType::Ilike(like) => Ok(Expr::_like(Like::new( like.negated, Box::new(parse_required_expr( like.expr.as_deref(), @@ -425,7 +424,7 @@ pub fn parse_expr( parse_escape_char(&like.escape_char)?, true, ))), - ExprType::SimilarTo(like) => Ok(Expr::SimilarTo(Like::new( + ExprType::SimilarTo(like) => Ok(Expr::similar_to(Like::new( like.negated, Box::new(parse_required_expr( like.expr.as_deref(), @@ -462,7 +461,7 @@ pub fn parse_expr( Ok((Box::new(when_expr), Box::new(then_expr))) }) .collect::, Box)>, Error>>()?; - Ok(Expr::Case(Case::new( + Ok(Expr::case(Case::new( parse_optional_expr(case.expr.as_deref(), registry, codec)?.map(Box::new), when_then_expr, parse_optional_expr(case.else_expr.as_deref(), registry, codec)? @@ -477,7 +476,7 @@ pub fn parse_expr( codec, )?); let data_type = cast.arrow_type.as_ref().required("arrow_type")?; - Ok(Expr::Cast(Cast::new(expr, data_type))) + Ok(Expr::cast(Cast::new(expr, data_type))) } ExprType::TryCast(cast) => { let expr = Box::new(parse_required_expr( @@ -487,9 +486,9 @@ pub fn parse_expr( codec, )?); let data_type = cast.arrow_type.as_ref().required("arrow_type")?; - Ok(Expr::TryCast(TryCast::new(expr, data_type))) + Ok(Expr::try_cast(TryCast::new(expr, data_type))) } - ExprType::Negative(negative) => Ok(Expr::Negative(Box::new( + ExprType::Negative(negative) => Ok(Expr::negative(Box::new( parse_required_expr(negative.expr.as_deref(), registry, "expr", codec)?, ))), ExprType::Unnest(unnest) => { @@ -497,9 +496,9 @@ pub fn parse_expr( if exprs.len() != 1 { return Err(proto_error("Unnest must have exactly one expression")); } - Ok(Expr::Unnest(Unnest::new(exprs.swap_remove(0)))) + Ok(Expr::unnest(Unnest::new(exprs.swap_remove(0)))) } - ExprType::InList(in_list) => Ok(Expr::InList(InList::new( + ExprType::InList(in_list) => Ok(Expr::_in_list(InList::new( Box::new(parse_required_expr( in_list.expr.as_deref(), registry, @@ -511,7 +510,7 @@ pub fn parse_expr( ))), ExprType::Wildcard(protobuf::Wildcard { qualifier }) => { let qualifier = qualifier.to_owned().map(|x| x.try_into()).transpose()?; - Ok(Expr::Wildcard(Wildcard { + Ok(Expr::wildcard(Wildcard { qualifier, options: WildcardOptions::default(), })) @@ -525,7 +524,7 @@ pub fn parse_expr( Some(buf) => codec.try_decode_udf(fun_name, buf)?, None => registry.udf(fun_name.as_str())?, }; - Ok(Expr::ScalarFunction(expr::ScalarFunction::new_udf( + Ok(Expr::scalar_function(expr::ScalarFunction::new_udf( scalar_fn, parse_exprs(args, registry, codec)?, ))) @@ -536,7 +535,7 @@ pub fn parse_expr( None => registry.udaf(&pb.fun_name)?, }; - Ok(Expr::AggregateFunction(expr::AggregateFunction::new_udf( + Ok(Expr::aggregate_function(expr::AggregateFunction::new_udf( agg_fn, parse_exprs(&pb.args, registry, codec)?, pb.distinct, @@ -550,21 +549,21 @@ pub fn parse_expr( } ExprType::GroupingSet(GroupingSetNode { expr }) => { - Ok(Expr::GroupingSet(GroupingSets( + Ok(Expr::grouping_set(GroupingSets( expr.iter() .map(|expr_list| parse_exprs(&expr_list.expr, registry, codec)) .collect::, Error>>()?, ))) } - ExprType::Cube(CubeNode { expr }) => Ok(Expr::GroupingSet(GroupingSet::Cube( + ExprType::Cube(CubeNode { expr }) => Ok(Expr::grouping_set(GroupingSet::Cube( parse_exprs(expr, registry, codec)?, ))), - ExprType::Rollup(RollupNode { expr }) => Ok(Expr::GroupingSet( + ExprType::Rollup(RollupNode { expr }) => Ok(Expr::grouping_set( GroupingSet::Rollup(parse_exprs(expr, registry, codec)?), )), ExprType::Placeholder(PlaceholderNode { id, data_type }) => match data_type { - None => Ok(Expr::Placeholder(Placeholder::new(id.clone(), None))), - Some(data_type) => Ok(Expr::Placeholder(Placeholder::new( + None => Ok(Expr::placeholder(Placeholder::new(id.clone(), None))), + Some(data_type) => Ok(Expr::placeholder(Placeholder::new( id.clone(), Some(data_type.try_into()?), ))), diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index 50636048ebc96..0aea72c394651 100644 --- a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -297,7 +297,7 @@ impl AsLogicalPlan for LogicalPlanNode { match projection.optional_alias.as_ref() { Some(a) => match a { protobuf::projection_node::OptionalAlias::Alias(alias) => { - Ok(LogicalPlan::SubqueryAlias(SubqueryAlias::try_new( + Ok(LogicalPlan::subquery_alias(SubqueryAlias::try_new( Arc::new(new_proj), alias.clone(), )?)) @@ -567,7 +567,7 @@ impl AsLogicalPlan for LogicalPlanNode { column_defaults.insert(col_name.clone(), expr); } - Ok(LogicalPlan::Ddl(DdlStatement::CreateExternalTable( + Ok(LogicalPlan::ddl(DdlStatement::CreateExternalTable( CreateExternalTable { schema: pb_schema.try_into()?, name: from_table_reference( @@ -602,7 +602,7 @@ impl AsLogicalPlan for LogicalPlanNode { None }; - Ok(LogicalPlan::Ddl(DdlStatement::CreateView(CreateView { + Ok(LogicalPlan::ddl(DdlStatement::CreateView(CreateView { name: from_table_reference(create_view.name.as_ref(), "CreateView")?, temporary: create_view.temporary, input: Arc::new(plan), @@ -617,7 +617,7 @@ impl AsLogicalPlan for LogicalPlanNode { )) })?; - Ok(LogicalPlan::Ddl(DdlStatement::CreateCatalogSchema( + Ok(LogicalPlan::ddl(DdlStatement::CreateCatalogSchema( CreateCatalogSchema { schema_name: create_catalog_schema.schema_name.clone(), if_not_exists: create_catalog_schema.if_not_exists, @@ -632,7 +632,7 @@ impl AsLogicalPlan for LogicalPlanNode { )) })?; - Ok(LogicalPlan::Ddl(DdlStatement::CreateCatalog( + Ok(LogicalPlan::ddl(DdlStatement::CreateCatalog( CreateCatalog { catalog_name: create_catalog.catalog_name.clone(), if_not_exists: create_catalog.if_not_exists, @@ -772,7 +772,7 @@ impl AsLogicalPlan for LogicalPlanNode { let extension_node = extension_codec.try_decode(node, &input_plans, ctx)?; - Ok(LogicalPlan::Extension(extension_node)) + Ok(LogicalPlan::extension(extension_node)) } LogicalPlanType::Distinct(distinct) => { let input: LogicalPlan = @@ -848,7 +848,7 @@ impl AsLogicalPlan for LogicalPlanNode { .build() } LogicalPlanType::DropView(dropview) => { - Ok(LogicalPlan::Ddl(DdlStatement::DropView(DropView { + Ok(LogicalPlan::ddl(DdlStatement::DropView(DropView { name: from_table_reference(dropview.name.as_ref(), "DropView")?, if_exists: dropview.if_exists, schema: Arc::new(convert_required!(dropview.schema)?), @@ -862,7 +862,7 @@ impl AsLogicalPlan for LogicalPlanNode { extension_codec.try_decode_file_format(©.file_type, ctx)?, ); - Ok(LogicalPlan::Copy(dml::CopyTo { + Ok(LogicalPlan::copy(dml::CopyTo { input: Arc::new(input), output_url: copy.output_url.clone(), partition_by: copy.partition_by.clone(), @@ -873,7 +873,7 @@ impl AsLogicalPlan for LogicalPlanNode { LogicalPlanType::Unnest(unnest) => { let input: LogicalPlan = into_logical_plan!(unnest.input, ctx, extension_codec)?; - Ok(LogicalPlan::Unnest(Unnest { + Ok(LogicalPlan::unnest(Unnest { input: Arc::new(input), exec_columns: unnest.exec_columns.iter().map(|c| c.into()).collect(), list_type_columns: unnest @@ -925,7 +925,7 @@ impl AsLogicalPlan for LogicalPlanNode { )))? .try_into_logical_plan(ctx, extension_codec)?; - Ok(LogicalPlan::RecursiveQuery(RecursiveQuery { + Ok(LogicalPlan::recursive_query(RecursiveQuery { name: recursive_query_node.name.clone(), static_term: Arc::new(static_term), recursive_term: Arc::new(recursive_term), @@ -954,7 +954,7 @@ impl AsLogicalPlan for LogicalPlanNode { Self: Sized, { match plan { - LogicalPlan::Values(Values { values, .. }) => { + LogicalPlan::Values(Values { values, .. }, _) => { let n_cols = if values.is_empty() { 0 } else { @@ -971,13 +971,16 @@ impl AsLogicalPlan for LogicalPlanNode { )), }) } - LogicalPlan::TableScan(TableScan { - table_name, - source, - filters, - projection, - .. - }) => { + LogicalPlan::TableScan( + TableScan { + table_name, + source, + filters, + projection, + .. + }, + _, + ) => { let provider = source_as_provider(source)?; let schema = provider.schema(); let source = provider.as_any(); @@ -1131,7 +1134,7 @@ impl AsLogicalPlan for LogicalPlanNode { Ok(node) } } - LogicalPlan::Projection(Projection { expr, input, .. }) => { + LogicalPlan::Projection(Projection { expr, input, .. }, _) => { Ok(LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::Projection(Box::new( protobuf::ProjectionNode { @@ -1147,7 +1150,7 @@ impl AsLogicalPlan for LogicalPlanNode { ))), }) } - LogicalPlan::Filter(filter) => { + LogicalPlan::Filter(filter, _) => { let input: LogicalPlanNode = LogicalPlanNode::try_from_logical_plan( filter.input.as_ref(), extension_codec, @@ -1164,7 +1167,7 @@ impl AsLogicalPlan for LogicalPlanNode { ))), }) } - LogicalPlan::Distinct(Distinct::All(input)) => { + LogicalPlan::Distinct(Distinct::All(input), _) => { let input: LogicalPlanNode = LogicalPlanNode::try_from_logical_plan( input.as_ref(), extension_codec, @@ -1177,13 +1180,16 @@ impl AsLogicalPlan for LogicalPlanNode { ))), }) } - LogicalPlan::Distinct(Distinct::On(DistinctOn { - on_expr, - select_expr, - sort_expr, - input, - .. - })) => { + LogicalPlan::Distinct( + Distinct::On(DistinctOn { + on_expr, + select_expr, + sort_expr, + input, + .. + }), + _, + ) => { let input: LogicalPlanNode = LogicalPlanNode::try_from_logical_plan( input.as_ref(), extension_codec, @@ -1203,9 +1209,12 @@ impl AsLogicalPlan for LogicalPlanNode { ))), }) } - LogicalPlan::Window(Window { - input, window_expr, .. - }) => { + LogicalPlan::Window( + Window { + input, window_expr, .. + }, + _, + ) => { let input: LogicalPlanNode = LogicalPlanNode::try_from_logical_plan( input.as_ref(), extension_codec, @@ -1219,12 +1228,15 @@ impl AsLogicalPlan for LogicalPlanNode { ))), }) } - LogicalPlan::Aggregate(Aggregate { - group_expr, - aggr_expr, - input, - .. - }) => { + LogicalPlan::Aggregate( + Aggregate { + group_expr, + aggr_expr, + input, + .. + }, + _, + ) => { let input: LogicalPlanNode = LogicalPlanNode::try_from_logical_plan( input.as_ref(), extension_codec, @@ -1239,16 +1251,19 @@ impl AsLogicalPlan for LogicalPlanNode { ))), }) } - LogicalPlan::Join(Join { - left, - right, - on, - filter, - join_type, - join_constraint, - null_equals_null, - .. - }) => { + LogicalPlan::Join( + Join { + left, + right, + on, + filter, + join_type, + join_constraint, + null_equals_null, + .. + }, + _, + ) => { let left: LogicalPlanNode = LogicalPlanNode::try_from_logical_plan( left.as_ref(), extension_codec, @@ -1290,10 +1305,10 @@ impl AsLogicalPlan for LogicalPlanNode { ))), }) } - LogicalPlan::Subquery(_) => { + LogicalPlan::Subquery(_, _) => { not_impl_err!("LogicalPlan serde is not yet implemented for subqueries") } - LogicalPlan::SubqueryAlias(SubqueryAlias { input, alias, .. }) => { + LogicalPlan::SubqueryAlias(SubqueryAlias { input, alias, .. }, _) => { let input: LogicalPlanNode = LogicalPlanNode::try_from_logical_plan( input.as_ref(), extension_codec, @@ -1307,7 +1322,7 @@ impl AsLogicalPlan for LogicalPlanNode { ))), }) } - LogicalPlan::Limit(limit) => { + LogicalPlan::Limit(limit, _) => { let input: LogicalPlanNode = LogicalPlanNode::try_from_logical_plan( limit.input.as_ref(), extension_codec, @@ -1333,7 +1348,7 @@ impl AsLogicalPlan for LogicalPlanNode { ))), }) } - LogicalPlan::Sort(Sort { input, expr, fetch }) => { + LogicalPlan::Sort(Sort { input, expr, fetch }, _) => { let input: LogicalPlanNode = LogicalPlanNode::try_from_logical_plan( input.as_ref(), extension_codec, @@ -1350,10 +1365,13 @@ impl AsLogicalPlan for LogicalPlanNode { ))), }) } - LogicalPlan::Repartition(Repartition { - input, - partitioning_scheme, - }) => { + LogicalPlan::Repartition( + Repartition { + input, + partitioning_scheme, + }, + _, + ) => { use datafusion::logical_expr::Partitioning; let input: LogicalPlanNode = LogicalPlanNode::try_from_logical_plan( input.as_ref(), @@ -1388,17 +1406,20 @@ impl AsLogicalPlan for LogicalPlanNode { ))), }) } - LogicalPlan::EmptyRelation(EmptyRelation { - produce_one_row, .. - }) => Ok(LogicalPlanNode { + LogicalPlan::EmptyRelation( + EmptyRelation { + produce_one_row, .. + }, + _, + ) => Ok(LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::EmptyRelation( protobuf::EmptyRelationNode { produce_one_row: *produce_one_row, }, )), }), - LogicalPlan::Ddl(DdlStatement::CreateExternalTable( - CreateExternalTable { + LogicalPlan::Ddl( + DdlStatement::CreateExternalTable(CreateExternalTable { name, location, file_type, @@ -1412,8 +1433,9 @@ impl AsLogicalPlan for LogicalPlanNode { constraints, column_defaults, temporary, - }, - )) => { + }), + _, + ) => { let mut converted_order_exprs: Vec = vec![]; for order in order_exprs { let temp = SortExprNodeCollection { @@ -1449,13 +1471,16 @@ impl AsLogicalPlan for LogicalPlanNode { )), }) } - LogicalPlan::Ddl(DdlStatement::CreateView(CreateView { - name, - input, - or_replace, - definition, - temporary, - })) => Ok(LogicalPlanNode { + LogicalPlan::Ddl( + DdlStatement::CreateView(CreateView { + name, + input, + or_replace, + definition, + temporary, + }), + _, + ) => Ok(LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::CreateView(Box::new( protobuf::CreateViewNode { name: Some(name.clone().into()), @@ -1469,13 +1494,14 @@ impl AsLogicalPlan for LogicalPlanNode { }, ))), }), - LogicalPlan::Ddl(DdlStatement::CreateCatalogSchema( - CreateCatalogSchema { + LogicalPlan::Ddl( + DdlStatement::CreateCatalogSchema(CreateCatalogSchema { schema_name, if_not_exists, schema: df_schema, - }, - )) => Ok(LogicalPlanNode { + }), + _, + ) => Ok(LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::CreateCatalogSchema( protobuf::CreateCatalogSchemaNode { schema_name: schema_name.clone(), @@ -1484,11 +1510,14 @@ impl AsLogicalPlan for LogicalPlanNode { }, )), }), - LogicalPlan::Ddl(DdlStatement::CreateCatalog(CreateCatalog { - catalog_name, - if_not_exists, - schema: df_schema, - })) => Ok(LogicalPlanNode { + LogicalPlan::Ddl( + DdlStatement::CreateCatalog(CreateCatalog { + catalog_name, + if_not_exists, + schema: df_schema, + }), + _, + ) => Ok(LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::CreateCatalog( protobuf::CreateCatalogNode { catalog_name: catalog_name.clone(), @@ -1497,7 +1526,7 @@ impl AsLogicalPlan for LogicalPlanNode { }, )), }), - LogicalPlan::Analyze(a) => { + LogicalPlan::Analyze(a, _) => { let input = LogicalPlanNode::try_from_logical_plan( a.input.as_ref(), extension_codec, @@ -1511,7 +1540,7 @@ impl AsLogicalPlan for LogicalPlanNode { ))), }) } - LogicalPlan::Explain(a) => { + LogicalPlan::Explain(a, _) => { let input = LogicalPlanNode::try_from_logical_plan( a.plan.as_ref(), extension_codec, @@ -1525,7 +1554,7 @@ impl AsLogicalPlan for LogicalPlanNode { ))), }) } - LogicalPlan::Union(union) => { + LogicalPlan::Union(union, _) => { let inputs: Vec = union .inputs .iter() @@ -1537,7 +1566,7 @@ impl AsLogicalPlan for LogicalPlanNode { )), }) } - LogicalPlan::Extension(extension) => { + LogicalPlan::Extension(extension, _) => { let mut buf: Vec = vec![]; extension_codec.try_encode(extension, &mut buf)?; @@ -1554,11 +1583,14 @@ impl AsLogicalPlan for LogicalPlanNode { )), }) } - LogicalPlan::Statement(Statement::Prepare(Prepare { - name, - data_types, - input, - })) => { + LogicalPlan::Statement( + Statement::Prepare(Prepare { + name, + data_types, + input, + }), + _, + ) => { let input = LogicalPlanNode::try_from_logical_plan(input, extension_codec)?; Ok(LogicalPlanNode { @@ -1574,15 +1606,18 @@ impl AsLogicalPlan for LogicalPlanNode { ))), }) } - LogicalPlan::Unnest(Unnest { - input, - exec_columns, - list_type_columns, - struct_type_columns, - dependency_indices, - schema, - options, - }) => { + LogicalPlan::Unnest( + Unnest { + input, + exec_columns, + list_type_columns, + struct_type_columns, + dependency_indices, + schema, + options, + }, + _, + ) => { let input = LogicalPlanNode::try_from_logical_plan(input, extension_codec)?; let proto_unnest_list_items = list_type_columns @@ -1618,20 +1653,23 @@ impl AsLogicalPlan for LogicalPlanNode { ))), }) } - LogicalPlan::Ddl(DdlStatement::CreateMemoryTable(_)) => Err(proto_error( + LogicalPlan::Ddl(DdlStatement::CreateMemoryTable(_), _) => Err(proto_error( "LogicalPlan serde is not yet implemented for CreateMemoryTable", )), - LogicalPlan::Ddl(DdlStatement::CreateIndex(_)) => Err(proto_error( + LogicalPlan::Ddl(DdlStatement::CreateIndex(_), _) => Err(proto_error( "LogicalPlan serde is not yet implemented for CreateIndex", )), - LogicalPlan::Ddl(DdlStatement::DropTable(_)) => Err(proto_error( + LogicalPlan::Ddl(DdlStatement::DropTable(_), _) => Err(proto_error( "LogicalPlan serde is not yet implemented for DropTable", )), - LogicalPlan::Ddl(DdlStatement::DropView(DropView { - name, - if_exists, - schema, - })) => Ok(LogicalPlanNode { + LogicalPlan::Ddl( + DdlStatement::DropView(DropView { + name, + if_exists, + schema, + }), + _, + ) => Ok(LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::DropView( protobuf::DropViewNode { name: Some(name.clone().into()), @@ -1640,28 +1678,31 @@ impl AsLogicalPlan for LogicalPlanNode { }, )), }), - LogicalPlan::Ddl(DdlStatement::DropCatalogSchema(_)) => Err(proto_error( + LogicalPlan::Ddl(DdlStatement::DropCatalogSchema(_), _) => Err(proto_error( "LogicalPlan serde is not yet implemented for DropCatalogSchema", )), - LogicalPlan::Ddl(DdlStatement::CreateFunction(_)) => Err(proto_error( + LogicalPlan::Ddl(DdlStatement::CreateFunction(_), _) => Err(proto_error( "LogicalPlan serde is not yet implemented for CreateFunction", )), - LogicalPlan::Ddl(DdlStatement::DropFunction(_)) => Err(proto_error( + LogicalPlan::Ddl(DdlStatement::DropFunction(_), _) => Err(proto_error( "LogicalPlan serde is not yet implemented for DropFunction", )), - LogicalPlan::Statement(_) => Err(proto_error( + LogicalPlan::Statement(_, _) => Err(proto_error( "LogicalPlan serde is not yet implemented for Statement", )), - LogicalPlan::Dml(_) => Err(proto_error( + LogicalPlan::Dml(_, _) => Err(proto_error( "LogicalPlan serde is not yet implemented for Dml", )), - LogicalPlan::Copy(dml::CopyTo { - input, - output_url, - file_type, - partition_by, - .. - }) => { + LogicalPlan::Copy( + dml::CopyTo { + input, + output_url, + file_type, + partition_by, + .. + }, + _, + ) => { let input = LogicalPlanNode::try_from_logical_plan(input, extension_codec)?; let mut buf = Vec::new(); @@ -1679,10 +1720,10 @@ impl AsLogicalPlan for LogicalPlanNode { ))), }) } - LogicalPlan::DescribeTable(_) => Err(proto_error( + LogicalPlan::DescribeTable(_, _) => Err(proto_error( "LogicalPlan serde is not yet implemented for DescribeTable", )), - LogicalPlan::RecursiveQuery(recursive) => { + LogicalPlan::RecursiveQuery(recursive, _) => { let static_term = LogicalPlanNode::try_from_logical_plan( recursive.static_term.as_ref(), extension_codec, diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index b3c91207ecf32..17343b98e6421 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -191,14 +191,17 @@ pub fn serialize_expr( use protobuf::logical_expr_node::ExprType; let expr_node = match expr { - Expr::Column(c) => protobuf::LogicalExprNode { + Expr::Column(c, _) => protobuf::LogicalExprNode { expr_type: Some(ExprType::Column(c.into())), }, - Expr::Alias(Alias { - expr, - relation, - name, - }) => { + Expr::Alias( + Alias { + expr, + relation, + name, + }, + _, + ) => { let alias = Box::new(protobuf::AliasNode { expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), relation: relation @@ -211,22 +214,25 @@ pub fn serialize_expr( expr_type: Some(ExprType::Alias(alias)), } } - Expr::Literal(value) => { + Expr::Literal(value, _) => { let pb_value: protobuf::ScalarValue = value.try_into()?; protobuf::LogicalExprNode { expr_type: Some(ExprType::Literal(pb_value)), } } - Expr::BinaryExpr(BinaryExpr { left, op, right }) => { + Expr::BinaryExpr(BinaryExpr { left, op, right }, _) => { // Try to linerize a nested binary expression tree of the same operator // into a flat vector of expressions. let mut exprs = vec![right.as_ref()]; let mut current_expr = left.as_ref(); - while let Expr::BinaryExpr(BinaryExpr { - left, - op: current_op, - right, - }) = current_expr + while let Expr::BinaryExpr( + BinaryExpr { + left, + op: current_op, + right, + }, + _, + ) = current_expr { if current_op == op { exprs.push(right.as_ref()); @@ -248,13 +254,16 @@ pub fn serialize_expr( expr_type: Some(ExprType::BinaryExpr(binary_expr)), } } - Expr::Like(Like { - negated, - expr, - pattern, - escape_char, - case_insensitive, - }) => { + Expr::Like( + Like { + negated, + expr, + pattern, + escape_char, + case_insensitive, + }, + _, + ) => { if *case_insensitive { let pb = Box::new(protobuf::ILikeNode { negated: *negated, @@ -279,13 +288,16 @@ pub fn serialize_expr( } } } - Expr::SimilarTo(Like { - negated, - expr, - pattern, - escape_char, - case_insensitive: _, - }) => { + Expr::SimilarTo( + Like { + negated, + expr, + pattern, + escape_char, + case_insensitive: _, + }, + _, + ) => { let pb = Box::new(protobuf::SimilarToNode { negated: *negated, expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), @@ -296,15 +308,18 @@ pub fn serialize_expr( expr_type: Some(ExprType::SimilarTo(pb)), } } - Expr::WindowFunction(expr::WindowFunction { - ref fun, - ref args, - ref partition_by, - ref order_by, - ref window_frame, - // TODO: support null treatment in proto - null_treatment: _, - }) => { + Expr::WindowFunction( + expr::WindowFunction { + ref fun, + ref args, + ref partition_by, + ref order_by, + ref window_frame, + // TODO: support null treatment in proto + null_treatment: _, + }, + _, + ) => { let (window_function, fun_definition) = match fun { WindowFunctionDefinition::AggregateUDF(aggr_udf) => { let mut buf = Vec::new(); @@ -344,14 +359,17 @@ pub fn serialize_expr( expr_type: Some(ExprType::WindowExpr(window_expr)), } } - Expr::AggregateFunction(expr::AggregateFunction { - ref func, - ref args, - ref distinct, - ref filter, - ref order_by, - null_treatment: _, - }) => { + Expr::AggregateFunction( + expr::AggregateFunction { + ref func, + ref args, + ref distinct, + ref filter, + ref order_by, + null_treatment: _, + }, + _, + ) => { let mut buf = Vec::new(); let _ = codec.try_encode_udaf(func, &mut buf); protobuf::LogicalExprNode { @@ -374,12 +392,12 @@ pub fn serialize_expr( } } - Expr::ScalarVariable(_, _) => { + Expr::ScalarVariable(_, _, _) => { return Err(Error::General( "Proto serialization error: Scalar Variable not supported".to_string(), )) } - Expr::ScalarFunction(ScalarFunction { func, args }) => { + Expr::ScalarFunction(ScalarFunction { func, args }, _) => { let mut buf = Vec::new(); let _ = codec.try_encode_udf(func, &mut buf); protobuf::LogicalExprNode { @@ -390,7 +408,7 @@ pub fn serialize_expr( })), } } - Expr::Not(expr) => { + Expr::Not(expr, _) => { let expr = Box::new(protobuf::Not { expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), }); @@ -398,7 +416,7 @@ pub fn serialize_expr( expr_type: Some(ExprType::NotExpr(expr)), } } - Expr::IsNull(expr) => { + Expr::IsNull(expr, _) => { let expr = Box::new(protobuf::IsNull { expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), }); @@ -406,7 +424,7 @@ pub fn serialize_expr( expr_type: Some(ExprType::IsNullExpr(expr)), } } - Expr::IsNotNull(expr) => { + Expr::IsNotNull(expr, _) => { let expr = Box::new(protobuf::IsNotNull { expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), }); @@ -414,7 +432,7 @@ pub fn serialize_expr( expr_type: Some(ExprType::IsNotNullExpr(expr)), } } - Expr::IsTrue(expr) => { + Expr::IsTrue(expr, _) => { let expr = Box::new(protobuf::IsTrue { expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), }); @@ -422,7 +440,7 @@ pub fn serialize_expr( expr_type: Some(ExprType::IsTrue(expr)), } } - Expr::IsFalse(expr) => { + Expr::IsFalse(expr, _) => { let expr = Box::new(protobuf::IsFalse { expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), }); @@ -430,7 +448,7 @@ pub fn serialize_expr( expr_type: Some(ExprType::IsFalse(expr)), } } - Expr::IsUnknown(expr) => { + Expr::IsUnknown(expr, _) => { let expr = Box::new(protobuf::IsUnknown { expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), }); @@ -438,7 +456,7 @@ pub fn serialize_expr( expr_type: Some(ExprType::IsUnknown(expr)), } } - Expr::IsNotTrue(expr) => { + Expr::IsNotTrue(expr, _) => { let expr = Box::new(protobuf::IsNotTrue { expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), }); @@ -446,7 +464,7 @@ pub fn serialize_expr( expr_type: Some(ExprType::IsNotTrue(expr)), } } - Expr::IsNotFalse(expr) => { + Expr::IsNotFalse(expr, _) => { let expr = Box::new(protobuf::IsNotFalse { expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), }); @@ -454,7 +472,7 @@ pub fn serialize_expr( expr_type: Some(ExprType::IsNotFalse(expr)), } } - Expr::IsNotUnknown(expr) => { + Expr::IsNotUnknown(expr, _) => { let expr = Box::new(protobuf::IsNotUnknown { expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), }); @@ -462,12 +480,15 @@ pub fn serialize_expr( expr_type: Some(ExprType::IsNotUnknown(expr)), } } - Expr::Between(Between { - expr, - negated, - low, - high, - }) => { + Expr::Between( + Between { + expr, + negated, + low, + high, + }, + _, + ) => { let expr = Box::new(protobuf::BetweenNode { expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), negated: *negated, @@ -478,7 +499,7 @@ pub fn serialize_expr( expr_type: Some(ExprType::Between(expr)), } } - Expr::Case(case) => { + Expr::Case(case, _) => { let when_then_expr = case .when_then_expr .iter() @@ -504,7 +525,7 @@ pub fn serialize_expr( expr_type: Some(ExprType::Case(expr)), } } - Expr::Cast(Cast { expr, data_type }) => { + Expr::Cast(Cast { expr, data_type }, _) => { let expr = Box::new(protobuf::CastNode { expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), arrow_type: Some(data_type.try_into()?), @@ -513,7 +534,7 @@ pub fn serialize_expr( expr_type: Some(ExprType::Cast(expr)), } } - Expr::TryCast(TryCast { expr, data_type }) => { + Expr::TryCast(TryCast { expr, data_type }, _) => { let expr = Box::new(protobuf::TryCastNode { expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), arrow_type: Some(data_type.try_into()?), @@ -522,7 +543,7 @@ pub fn serialize_expr( expr_type: Some(ExprType::TryCast(expr)), } } - Expr::Negative(expr) => { + Expr::Negative(expr, _) => { let expr = Box::new(protobuf::NegativeNode { expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), }); @@ -530,7 +551,7 @@ pub fn serialize_expr( expr_type: Some(ExprType::Negative(expr)), } } - Expr::Unnest(Unnest { expr }) => { + Expr::Unnest(Unnest { expr }, _) => { let expr = protobuf::Unnest { exprs: vec![serialize_expr(expr.as_ref(), codec)?], }; @@ -538,11 +559,14 @@ pub fn serialize_expr( expr_type: Some(ExprType::Unnest(expr)), } } - Expr::InList(InList { - expr, - list, - negated, - }) => { + Expr::InList( + InList { + expr, + list, + negated, + }, + _, + ) => { let expr = Box::new(protobuf::InListNode { expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), list: serialize_exprs(list, codec)?, @@ -552,30 +576,30 @@ pub fn serialize_expr( expr_type: Some(ExprType::InList(expr)), } } - Expr::Wildcard(Wildcard { qualifier, .. }) => protobuf::LogicalExprNode { + Expr::Wildcard(Wildcard { qualifier, .. }, _) => protobuf::LogicalExprNode { expr_type: Some(ExprType::Wildcard(protobuf::Wildcard { qualifier: qualifier.to_owned().map(|x| x.into()), })), }, - Expr::ScalarSubquery(_) - | Expr::InSubquery(_) + Expr::ScalarSubquery(_, _) + | Expr::InSubquery(_, _) | Expr::Exists { .. } | Expr::OuterReferenceColumn { .. } => { // we would need to add logical plan operators to datafusion.proto to support this // see discussion in https://github.com/apache/datafusion/issues/2565 return Err(Error::General("Proto serialization error: Expr::ScalarSubquery(_) | Expr::InSubquery(_) | Expr::Exists { .. } | Exp:OuterReferenceColumn not supported".to_string())); } - Expr::GroupingSet(GroupingSet::Cube(exprs)) => protobuf::LogicalExprNode { + Expr::GroupingSet(GroupingSet::Cube(exprs), _) => protobuf::LogicalExprNode { expr_type: Some(ExprType::Cube(CubeNode { expr: serialize_exprs(exprs, codec)?, })), }, - Expr::GroupingSet(GroupingSet::Rollup(exprs)) => protobuf::LogicalExprNode { + Expr::GroupingSet(GroupingSet::Rollup(exprs), _) => protobuf::LogicalExprNode { expr_type: Some(ExprType::Rollup(RollupNode { expr: serialize_exprs(exprs, codec)?, })), }, - Expr::GroupingSet(GroupingSet::GroupingSets(exprs)) => { + Expr::GroupingSet(GroupingSet::GroupingSets(exprs), _) => { protobuf::LogicalExprNode { expr_type: Some(ExprType::GroupingSet(GroupingSetNode { expr: exprs @@ -589,7 +613,7 @@ pub fn serialize_expr( })), } } - Expr::Placeholder(Placeholder { id, data_type }) => { + Expr::Placeholder(Placeholder { id, data_type }, _) => { let data_type = match data_type { Some(data_type) => Some(data_type.try_into()?), None => None, diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index deece2e54a5a6..f566235d27547 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -133,7 +133,7 @@ async fn roundtrip_logical_plan() -> Result<()> { ctx.register_csv("t1", "tests/testdata/test.csv", CsvReadOptions::default()) .await?; let scan = ctx.table("t1").await?.into_optimized_plan()?; - let topk_plan = LogicalPlan::Extension(Extension { + let topk_plan = LogicalPlan::extension(Extension { node: Arc::new(TopKPlanNode::new(3, scan, col("revenue"))), }); let extension_codec = TopKExtensionCodec {}; @@ -380,7 +380,7 @@ async fn roundtrip_logical_plan_copy_to_sql_options() -> Result<()> { let input = create_csv_scan(&ctx).await?; let file_type = format_as_file_type(Arc::new(CsvFormatFactory::new())); - let plan = LogicalPlan::Copy(CopyTo { + let plan = LogicalPlan::copy(CopyTo { input: Arc::new(input), output_url: "test.csv".to_string(), partition_by: vec!["a".to_string(), "b".to_string(), "c".to_string()], @@ -420,7 +420,7 @@ async fn roundtrip_logical_plan_copy_to_writer_options() -> Result<()> { ParquetFormatFactory::new_with_options(parquet_format), )); - let plan = LogicalPlan::Copy(CopyTo { + let plan = LogicalPlan::copy(CopyTo { input: Arc::new(input), output_url: "test.parquet".to_string(), file_type, @@ -434,7 +434,7 @@ async fn roundtrip_logical_plan_copy_to_writer_options() -> Result<()> { logical_plan_from_bytes_with_extension_codec(&bytes, &ctx, &codec)?; assert_eq!(format!("{plan:?}"), format!("{logical_round_trip:?}")); match logical_round_trip { - LogicalPlan::Copy(copy_to) => { + LogicalPlan::Copy(copy_to, _) => { assert_eq!("test.parquet", copy_to.output_url); assert_eq!(vec!["a", "b", "c"], copy_to.partition_by); assert_eq!(copy_to.file_type.get_ext(), "parquet".to_string()); @@ -452,7 +452,7 @@ async fn roundtrip_logical_plan_copy_to_arrow() -> Result<()> { let file_type = format_as_file_type(Arc::new(ArrowFormatFactory::new())); - let plan = LogicalPlan::Copy(CopyTo { + let plan = LogicalPlan::copy(CopyTo { input: Arc::new(input), output_url: "test.arrow".to_string(), partition_by: vec!["a".to_string(), "b".to_string(), "c".to_string()], @@ -467,7 +467,7 @@ async fn roundtrip_logical_plan_copy_to_arrow() -> Result<()> { assert_eq!(format!("{plan}"), format!("{logical_round_trip}")); match logical_round_trip { - LogicalPlan::Copy(copy_to) => { + LogicalPlan::Copy(copy_to, _) => { assert_eq!("test.arrow", copy_to.output_url); assert_eq!("arrow".to_string(), copy_to.file_type.get_ext()); assert_eq!(vec!["a", "b", "c"], copy_to.partition_by); @@ -499,7 +499,7 @@ async fn roundtrip_logical_plan_copy_to_csv() -> Result<()> { csv_format.clone(), ))); - let plan = LogicalPlan::Copy(CopyTo { + let plan = LogicalPlan::copy(CopyTo { input: Arc::new(input), output_url: "test.csv".to_string(), partition_by: vec!["a".to_string(), "b".to_string(), "c".to_string()], @@ -514,7 +514,7 @@ async fn roundtrip_logical_plan_copy_to_csv() -> Result<()> { assert_eq!(format!("{plan:?}"), format!("{logical_round_trip:?}")); match logical_round_trip { - LogicalPlan::Copy(copy_to) => { + LogicalPlan::Copy(copy_to, _) => { assert_eq!("test.csv", copy_to.output_url); assert_eq!("csv".to_string(), copy_to.file_type.get_ext()); assert_eq!(vec!["a", "b", "c"], copy_to.partition_by); @@ -565,7 +565,7 @@ async fn roundtrip_logical_plan_copy_to_json() -> Result<()> { json_format.clone(), ))); - let plan = LogicalPlan::Copy(CopyTo { + let plan = LogicalPlan::copy(CopyTo { input: Arc::new(input), output_url: "test.json".to_string(), partition_by: vec!["a".to_string(), "b".to_string(), "c".to_string()], @@ -581,7 +581,7 @@ async fn roundtrip_logical_plan_copy_to_json() -> Result<()> { assert_eq!(format!("{plan}"), format!("{logical_round_trip}")); match logical_round_trip { - LogicalPlan::Copy(copy_to) => { + LogicalPlan::Copy(copy_to, _) => { assert_eq!("test.json", copy_to.output_url); assert_eq!("json".to_string(), copy_to.file_type.get_ext()); assert_eq!(vec!["a", "b", "c"], copy_to.partition_by); @@ -637,7 +637,7 @@ async fn roundtrip_logical_plan_copy_to_parquet() -> Result<()> { ParquetFormatFactory::new_with_options(parquet_format.clone()), )); - let plan = LogicalPlan::Copy(CopyTo { + let plan = LogicalPlan::copy(CopyTo { input: Arc::new(input), output_url: "test.parquet".to_string(), partition_by: vec!["a".to_string(), "b".to_string(), "c".to_string()], @@ -653,7 +653,7 @@ async fn roundtrip_logical_plan_copy_to_parquet() -> Result<()> { assert_eq!(format!("{plan}"), format!("{logical_round_trip}")); match logical_round_trip { - LogicalPlan::Copy(copy_to) => { + LogicalPlan::Copy(copy_to, _) => { assert_eq!("test.parquet", copy_to.output_url); assert_eq!("parquet".to_string(), copy_to.file_type.get_ext()); assert_eq!(vec!["a", "b", "c"], copy_to.partition_by); @@ -1902,7 +1902,7 @@ fn roundtrip_dfschema() { #[test] fn roundtrip_not() { - let test_expr = Expr::Not(Box::new(lit(1.0_f32))); + let test_expr = Expr::_not(Box::new(lit(1.0_f32))); let ctx = SessionContext::new(); roundtrip_expr_test(test_expr, ctx); @@ -1910,7 +1910,7 @@ fn roundtrip_not() { #[test] fn roundtrip_is_null() { - let test_expr = Expr::IsNull(Box::new(col("id"))); + let test_expr = Expr::_is_null(Box::new(col("id"))); let ctx = SessionContext::new(); roundtrip_expr_test(test_expr, ctx); @@ -1918,7 +1918,7 @@ fn roundtrip_is_null() { #[test] fn roundtrip_is_not_null() { - let test_expr = Expr::IsNotNull(Box::new(col("id"))); + let test_expr = Expr::_is_not_null(Box::new(col("id"))); let ctx = SessionContext::new(); roundtrip_expr_test(test_expr, ctx); @@ -1926,7 +1926,7 @@ fn roundtrip_is_not_null() { #[test] fn roundtrip_between() { - let test_expr = Expr::Between(Between::new( + let test_expr = Expr::_between(Between::new( Box::new(lit(1.0_f32)), true, Box::new(lit(2.0_f32)), @@ -1940,7 +1940,7 @@ fn roundtrip_between() { #[test] fn roundtrip_binary_op() { fn test(op: Operator) { - let test_expr = Expr::BinaryExpr(BinaryExpr::new( + let test_expr = Expr::binary_expr(BinaryExpr::new( Box::new(lit(1.0_f32)), op, Box::new(lit(2.0_f32)), @@ -1974,7 +1974,7 @@ fn roundtrip_binary_op() { #[test] fn roundtrip_case() { - let test_expr = Expr::Case(Case::new( + let test_expr = Expr::case(Case::new( Some(Box::new(lit(1.0_f32))), vec![(Box::new(lit(2.0_f32)), Box::new(lit(3.0_f32)))], Some(Box::new(lit(4.0_f32))), @@ -1986,10 +1986,10 @@ fn roundtrip_case() { #[test] fn roundtrip_case_with_null() { - let test_expr = Expr::Case(Case::new( + let test_expr = Expr::case(Case::new( Some(Box::new(lit(1.0_f32))), vec![(Box::new(lit(2.0_f32)), Box::new(lit(3.0_f32)))], - Some(Box::new(Expr::Literal(ScalarValue::Null))), + Some(Box::new(Expr::literal(ScalarValue::Null))), )); let ctx = SessionContext::new(); @@ -1998,7 +1998,7 @@ fn roundtrip_case_with_null() { #[test] fn roundtrip_null_literal() { - let test_expr = Expr::Literal(ScalarValue::Null); + let test_expr = Expr::literal(ScalarValue::Null); let ctx = SessionContext::new(); roundtrip_expr_test(test_expr, ctx); @@ -2006,7 +2006,7 @@ fn roundtrip_null_literal() { #[test] fn roundtrip_cast() { - let test_expr = Expr::Cast(Cast::new(Box::new(lit(1.0_f32)), DataType::Boolean)); + let test_expr = Expr::cast(Cast::new(Box::new(lit(1.0_f32)), DataType::Boolean)); let ctx = SessionContext::new(); roundtrip_expr_test(test_expr, ctx); @@ -2015,13 +2015,13 @@ fn roundtrip_cast() { #[test] fn roundtrip_try_cast() { let test_expr = - Expr::TryCast(TryCast::new(Box::new(lit(1.0_f32)), DataType::Boolean)); + Expr::try_cast(TryCast::new(Box::new(lit(1.0_f32)), DataType::Boolean)); let ctx = SessionContext::new(); roundtrip_expr_test(test_expr, ctx); let test_expr = - Expr::TryCast(TryCast::new(Box::new(lit("not a bool")), DataType::Boolean)); + Expr::try_cast(TryCast::new(Box::new(lit("not a bool")), DataType::Boolean)); let ctx = SessionContext::new(); roundtrip_expr_test(test_expr, ctx); @@ -2029,7 +2029,7 @@ fn roundtrip_try_cast() { #[test] fn roundtrip_negative() { - let test_expr = Expr::Negative(Box::new(lit(1.0_f32))); + let test_expr = Expr::negative(Box::new(lit(1.0_f32))); let ctx = SessionContext::new(); roundtrip_expr_test(test_expr, ctx); @@ -2037,7 +2037,7 @@ fn roundtrip_negative() { #[test] fn roundtrip_inlist() { - let test_expr = Expr::InList(InList::new( + let test_expr = Expr::_in_list(InList::new( Box::new(lit(1.0_f32)), vec![lit(2.0_f32)], true, @@ -2049,7 +2049,7 @@ fn roundtrip_inlist() { #[test] fn roundtrip_unnest() { - let test_expr = Expr::Unnest(Unnest { + let test_expr = Expr::unnest(Unnest { expr: Box::new(col("col")), }); @@ -2059,7 +2059,7 @@ fn roundtrip_unnest() { #[test] fn roundtrip_wildcard() { - let test_expr = Expr::Wildcard(Wildcard { + let test_expr = Expr::wildcard(Wildcard { qualifier: None, options: WildcardOptions::default(), }); @@ -2070,7 +2070,7 @@ fn roundtrip_wildcard() { #[test] fn roundtrip_qualified_wildcard() { - let test_expr = Expr::Wildcard(Wildcard { + let test_expr = Expr::wildcard(Wildcard { qualifier: Some("foo".into()), options: WildcardOptions::default(), }); @@ -2082,7 +2082,7 @@ fn roundtrip_qualified_wildcard() { #[test] fn roundtrip_like() { fn like(negated: bool, escape_char: Option) { - let test_expr = Expr::Like(Like::new( + let test_expr = Expr::_like(Like::new( negated, Box::new(col("col")), Box::new(lit("[0-9]+")), @@ -2101,7 +2101,7 @@ fn roundtrip_like() { #[test] fn roundtrip_ilike() { fn ilike(negated: bool, escape_char: Option) { - let test_expr = Expr::Like(Like::new( + let test_expr = Expr::_like(Like::new( negated, Box::new(col("col")), Box::new(lit("[0-9]+")), @@ -2120,7 +2120,7 @@ fn roundtrip_ilike() { #[test] fn roundtrip_similar_to() { fn similar_to(negated: bool, escape_char: Option) { - let test_expr = Expr::SimilarTo(Like::new( + let test_expr = Expr::similar_to(Like::new( negated, Box::new(col("col")), Box::new(lit("[0-9]+")), @@ -2195,7 +2195,7 @@ fn roundtrip_aggregate_udf() { Arc::new(vec![DataType::Float64, DataType::UInt32]), ); - let test_expr = Expr::AggregateFunction(expr::AggregateFunction::new_udf( + let test_expr = Expr::aggregate_function(expr::AggregateFunction::new_udf( Arc::new(dummy_agg.clone()), vec![lit(1.0_f64)], false, @@ -2227,7 +2227,7 @@ fn roundtrip_scalar_udf() { scalar_fn, ); - let test_expr = Expr::ScalarFunction(ScalarFunction::new_udf( + let test_expr = Expr::scalar_function(ScalarFunction::new_udf( Arc::new(udf.clone()), vec![lit("")], )); @@ -2266,7 +2266,7 @@ fn roundtrip_aggregate_udf_extension_codec() { #[test] fn roundtrip_grouping_sets() { - let test_expr = Expr::GroupingSet(GroupingSet::GroupingSets(vec![ + let test_expr = Expr::grouping_set(GroupingSet::GroupingSets(vec![ vec![col("a")], vec![col("b")], vec![col("a"), col("b")], @@ -2278,7 +2278,7 @@ fn roundtrip_grouping_sets() { #[test] fn roundtrip_rollup() { - let test_expr = Expr::GroupingSet(GroupingSet::Rollup(vec![col("a"), col("b")])); + let test_expr = Expr::grouping_set(GroupingSet::Rollup(vec![col("a"), col("b")])); let ctx = SessionContext::new(); roundtrip_expr_test(test_expr, ctx); @@ -2286,7 +2286,7 @@ fn roundtrip_rollup() { #[test] fn roundtrip_cube() { - let test_expr = Expr::GroupingSet(GroupingSet::Cube(vec![col("a"), col("b")])); + let test_expr = Expr::grouping_set(GroupingSet::Cube(vec![col("a"), col("b")])); let ctx = SessionContext::new(); roundtrip_expr_test(test_expr, ctx); @@ -2305,13 +2305,13 @@ fn roundtrip_substr() { .unwrap(); // substr(string, position) - let test_expr = Expr::ScalarFunction(ScalarFunction::new_udf( + let test_expr = Expr::scalar_function(ScalarFunction::new_udf( fun.clone(), vec![col("col"), lit(1_i64)], )); // substr(string, position, count) - let test_expr_with_count = Expr::ScalarFunction(ScalarFunction::new_udf( + let test_expr_with_count = Expr::scalar_function(ScalarFunction::new_udf( fun, vec![col("col"), lit(1_i64), lit(1_i64)], )); @@ -2324,7 +2324,7 @@ fn roundtrip_window() { let ctx = SessionContext::new(); // 1. without window_frame - let test_expr1 = Expr::WindowFunction(expr::WindowFunction::new( + let test_expr1 = Expr::window_function(expr::WindowFunction::new( WindowFunctionDefinition::WindowUDF(rank_udwf()), vec![], )) @@ -2335,7 +2335,7 @@ fn roundtrip_window() { .unwrap(); // 2. with default window_frame - let test_expr2 = Expr::WindowFunction(expr::WindowFunction::new( + let test_expr2 = Expr::window_function(expr::WindowFunction::new( WindowFunctionDefinition::WindowUDF(rank_udwf()), vec![], )) @@ -2352,7 +2352,7 @@ fn roundtrip_window() { WindowFrameBound::Following(ScalarValue::UInt64(Some(2))), ); - let test_expr3 = Expr::WindowFunction(expr::WindowFunction::new( + let test_expr3 = Expr::window_function(expr::WindowFunction::new( WindowFunctionDefinition::WindowUDF(rank_udwf()), vec![], )) @@ -2369,7 +2369,7 @@ fn roundtrip_window() { WindowFrameBound::Following(ScalarValue::UInt64(Some(2))), ); - let test_expr4 = Expr::WindowFunction(expr::WindowFunction::new( + let test_expr4 = Expr::window_function(expr::WindowFunction::new( WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![col("col1")], )) @@ -2419,7 +2419,7 @@ fn roundtrip_window() { Arc::new(vec![DataType::Float64, DataType::UInt32]), ); - let test_expr5 = Expr::WindowFunction(expr::WindowFunction::new( + let test_expr5 = Expr::window_function(expr::WindowFunction::new( WindowFunctionDefinition::AggregateUDF(Arc::new(dummy_agg.clone())), vec![col("col1")], )) @@ -2500,7 +2500,7 @@ fn roundtrip_window() { let dummy_window_udf = WindowUDF::from(SimpleWindowUDF::new()); - let test_expr6 = Expr::WindowFunction(expr::WindowFunction::new( + let test_expr6 = Expr::window_function(expr::WindowFunction::new( WindowFunctionDefinition::WindowUDF(Arc::new(dummy_window_udf.clone())), vec![col("col1")], )) @@ -2510,7 +2510,7 @@ fn roundtrip_window() { .build() .unwrap(); - let text_expr7 = Expr::WindowFunction(expr::WindowFunction::new( + let text_expr7 = Expr::window_function(expr::WindowFunction::new( WindowFunctionDefinition::AggregateUDF(avg_udaf()), vec![col("col1")], )) diff --git a/datafusion/proto/tests/cases/serialize.rs b/datafusion/proto/tests/cases/serialize.rs index d1b50105d053d..74c0db34d54eb 100644 --- a/datafusion/proto/tests/cases/serialize.rs +++ b/datafusion/proto/tests/cases/serialize.rs @@ -44,7 +44,7 @@ fn plan_to_json() { use datafusion_expr::{logical_plan::EmptyRelation, LogicalPlan}; use datafusion_proto::bytes::logical_plan_to_json; - let plan = LogicalPlan::EmptyRelation(EmptyRelation { + let plan = LogicalPlan::empty_relation(EmptyRelation { produce_one_row: false, schema: Arc::new(DFSchema::empty()), }); @@ -62,7 +62,7 @@ fn json_to_plan() { let input = r#"{"emptyRelation":{}}"#.to_string(); let ctx = SessionContext::new(); let actual = logical_plan_from_json(&input, &ctx).unwrap(); - let result = matches!(actual, LogicalPlan::EmptyRelation(_)); + let result = matches!(actual, LogicalPlan::EmptyRelation(_, _)); assert!(result, "Should parse empty relation"); } @@ -256,12 +256,12 @@ fn test_expression_serialization_roundtrip() { use datafusion_proto::logical_plan::from_proto::parse_expr; let ctx = SessionContext::new(); - let lit = Expr::Literal(ScalarValue::Utf8(None)); + let lit = Expr::literal(ScalarValue::Utf8(None)); for function in string::functions() { // default to 4 args (though some exprs like substr have error checking) let num_args = 4; let args: Vec<_> = std::iter::repeat(&lit).take(num_args).cloned().collect(); - let expr = Expr::ScalarFunction(ScalarFunction::new_udf(function, args)); + let expr = Expr::scalar_function(ScalarFunction::new_udf(function, args)); let extension_codec = DefaultLogicalExtensionCodec {}; let proto = serialize_expr(&expr, &extension_codec).unwrap(); diff --git a/datafusion/sql/src/cte.rs b/datafusion/sql/src/cte.rs index c288d6ca70674..71c94890b93c6 100644 --- a/datafusion/sql/src/cte.rs +++ b/datafusion/sql/src/cte.rs @@ -193,7 +193,7 @@ fn has_work_table_reference( ) -> bool { let mut has_reference = false; plan.apply(|node| { - if let LogicalPlan::TableScan(scan) = node { + if let LogicalPlan::TableScan(scan, _) = node { if Arc::ptr_eq(&scan.source, work_table_source) { has_reference = true; return Ok(TreeNodeRecursion::Stop); diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index 0ce7c891a6085..e7c6c76f546a3 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -22,8 +22,8 @@ use datafusion_common::{ internal_datafusion_err, internal_err, not_impl_err, plan_datafusion_err, plan_err, DFSchema, Dependency, Result, }; -use datafusion_expr::expr::{Wildcard, WildcardOptions}; use datafusion_expr::expr::{ScalarFunction, Unnest}; +use datafusion_expr::expr::{Wildcard, WildcardOptions}; use datafusion_expr::planner::PlannerResult; use datafusion_expr::{ expr, Expr, ExprFunctionExt, ExprSchemable, WindowFrame, WindowFunctionDefinition, @@ -235,7 +235,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // User-defined function (UDF) should have precedence if let Some(fm) = self.context_provider.get_function_meta(&name) { let args = self.function_args_to_expr(args, schema, planner_context)?; - return Ok(Expr::ScalarFunction(ScalarFunction::new_udf(fm, args))); + return Ok(Expr::scalar_function(ScalarFunction::new_udf(fm, args))); } // Build Unnest expression @@ -246,7 +246,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } let expr = exprs.swap_remove(0); Self::check_unnest_arg(&expr, schema)?; - return Ok(Expr::Unnest(Unnest::new(expr))); + return Ok(Expr::unnest(Unnest::new(expr))); } if !order_by.is_empty() && is_function_window { @@ -277,7 +277,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let func_deps = schema.functional_dependencies(); // Find whether ties are possible in the given ordering let is_ordering_strict = order_by.iter().find_map(|orderby_expr| { - if let Expr::Column(col) = &orderby_expr.expr { + if let Expr::Column(col, _) = &orderby_expr.expr { let idx = schema.index_of_column(col).ok()?; return if func_deps.iter().any(|dep| { dep.source_indices == vec![idx] && dep.mode == Dependency::Single @@ -310,7 +310,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { }; if let Ok(fun) = self.find_window_func(&name) { - return Expr::WindowFunction(expr::WindowFunction::new( + return Expr::window_function(expr::WindowFunction::new( fun, self.function_args_to_expr(args, schema, planner_context)?, )) @@ -336,7 +336,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .map(|e| self.sql_expr_to_logical_expr(*e, schema, planner_context)) .transpose()? .map(Box::new); - return Ok(Expr::AggregateFunction(expr::AggregateFunction::new_udf( + return Ok(Expr::aggregate_function(expr::AggregateFunction::new_udf( fm, args, distinct, @@ -371,7 +371,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { internal_datafusion_err!("Unable to find expected '{fn_name}' function") })?; let args = vec![self.sql_expr_to_logical_expr(expr, schema, planner_context)?]; - Ok(Expr::ScalarFunction(ScalarFunction::new_udf(fun, args))) + Ok(Expr::scalar_function(ScalarFunction::new_udf(fun, args))) } pub(super) fn find_window_func( @@ -413,7 +413,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { name: _, arg: FunctionArgExpr::Wildcard, operator: _, - } => Ok(Expr::Wildcard(Wildcard { + } => Ok(Expr::wildcard(Wildcard { qualifier: None, options: WildcardOptions::default(), })), @@ -421,7 +421,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { self.sql_expr_to_logical_expr(arg, schema, planner_context) } FunctionArg::Unnamed(FunctionArgExpr::Wildcard) => { - Ok(Expr::Wildcard(Wildcard { + Ok(Expr::wildcard(Wildcard { qualifier: None, options: WildcardOptions::default(), })) @@ -433,7 +433,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { if qualified_indices.is_empty() { return plan_err!("Invalid qualifier {qualifier}"); } - Ok(Expr::Wildcard(Wildcard { + Ok(Expr::wildcard(Wildcard { qualifier: Some(qualifier), options: WildcardOptions::default(), })) diff --git a/datafusion/sql/src/expr/grouping_set.rs b/datafusion/sql/src/expr/grouping_set.rs index a8b3ef7e20ec2..ad8613e4596cc 100644 --- a/datafusion/sql/src/expr/grouping_set.rs +++ b/datafusion/sql/src/expr/grouping_set.rs @@ -36,7 +36,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .collect() }) .collect(); - Ok(Expr::GroupingSet(GroupingSet::GroupingSets(args?))) + Ok(Expr::grouping_set(GroupingSet::GroupingSets(args?))) } pub(super) fn sql_rollup_to_expr( @@ -57,7 +57,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } }) .collect(); - Ok(Expr::GroupingSet(GroupingSet::Rollup(args?))) + Ok(Expr::grouping_set(GroupingSet::Rollup(args?))) } pub(super) fn sql_cube_to_expr( @@ -76,6 +76,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } }) .collect(); - Ok(Expr::GroupingSet(GroupingSet::Cube(args?))) + Ok(Expr::grouping_set(GroupingSet::Cube(args?))) } } diff --git a/datafusion/sql/src/expr/identifier.rs b/datafusion/sql/src/expr/identifier.rs index e103f68fc9275..4d0eda55fc494 100644 --- a/datafusion/sql/src/expr/identifier.rs +++ b/datafusion/sql/src/expr/identifier.rs @@ -44,7 +44,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .ok_or_else(|| { plan_datafusion_err!("variable {var_names:?} has no type information") })?; - Ok(Expr::ScalarVariable(ty, var_names)) + Ok(Expr::scalar_variable(ty, var_names)) } else { // Don't use `col()` here because it will try to // interpret names with '.' as if they were @@ -56,7 +56,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { if let Ok((qualifier, _)) = schema.qualified_field_with_unqualified_name(normalize_ident.as_str()) { - return Ok(Expr::Column(Column { + return Ok(Expr::column(Column { relation: qualifier.filter(|q| q.table() != UNNAMED_TABLE).cloned(), name: normalize_ident, })); @@ -68,7 +68,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { outer.qualified_field_with_unqualified_name(normalize_ident.as_str()) { // Found an exact match on a qualified name in the outer plan schema, so this is an outer reference column - return Ok(Expr::OuterReferenceColumn( + return Ok(Expr::outer_reference_column( field.data_type().clone(), Column::from((qualifier, field)), )); @@ -76,7 +76,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } // Default case - Ok(Expr::Column(Column { + Ok(Expr::column(Column { relation: None, name: normalize_ident, })) @@ -106,7 +106,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { "variable {var_names:?} has no type information" )) })?; - Ok(Expr::ScalarVariable(ty, var_names)) + Ok(Expr::scalar_variable(ty, var_names)) } else { let ids = ids .into_iter() @@ -136,7 +136,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } // Found matching field with no spare identifier(s) Some((field, qualifier, _nested_names)) => { - Ok(Expr::Column(Column::from((qualifier, field)))) + Ok(Expr::column(Column::from((qualifier, field)))) } None => { // Return default where use all identifiers to not have a nested field @@ -161,7 +161,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // Found matching field with no spare identifier(s) Some((field, qualifier, _nested_names)) => { // Found an exact match on a qualified name in the outer plan schema, so this is an outer reference column - Ok(Expr::OuterReferenceColumn( + Ok(Expr::outer_reference_column( field.data_type().clone(), Column::from((qualifier, field)), )) @@ -172,14 +172,14 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // safe unwrap as s can never be empty or exceed the bounds let (relation, column_name) = form_identifier(s).unwrap(); - Ok(Expr::Column(Column::new(relation, column_name))) + Ok(Expr::column(Column::new(relation, column_name))) } } } else { let s = &ids[0..ids.len()]; // Safe unwrap as s can never be empty or exceed the bounds let (relation, column_name) = form_identifier(s).unwrap(); - Ok(Expr::Column(Column::new(relation, column_name))) + Ok(Expr::column(Column::new(relation, column_name))) } } } @@ -223,7 +223,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { None }; - Ok(Expr::Case(Case::new( + Ok(Expr::case(Case::new( expr, when_expr .iter() diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index 95ee42cc99cf2..822029d9f1b2a 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -124,7 +124,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } let RawBinaryExpr { op, left, right } = binary_expr; - Ok(Expr::BinaryExpr(BinaryExpr::new( + Ok(Expr::binary_expr(BinaryExpr::new( Box::new(left), self.parse_sql_binary_op(op)?, Box::new(right), @@ -148,7 +148,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { /// Rewrite aliases which are not-complete (e.g. ones that only include only table qualifier in a schema.table qualified relation) fn rewrite_partial_qualifier(&self, expr: Expr, schema: &DFSchema) -> Expr { match expr { - Expr::Column(col) => match &col.relation { + Expr::Column(col, _) => match &col.relation { Some(q) => { match schema.iter().find(|(qualifier, field)| match qualifier { Some(field_q) => { @@ -158,10 +158,10 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { _ => false, }) { Some((qualifier, df_field)) => Expr::from((qualifier, df_field)), - None => Expr::Column(col), + None => Expr::column(col), } } - None => Expr::Column(col), + None => Expr::column(col), }, _ => expr, } @@ -187,7 +187,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } SQLExpr::Extract { field, expr, .. } => { let mut extract_args = vec![ - Expr::Literal(ScalarValue::from(format!("{field}"))), + Expr::literal(ScalarValue::from(format!("{field}"))), self.sql_expr_to_logical_expr(*expr, schema, planner_context)?, ]; @@ -253,7 +253,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { return not_impl_err!("CAST with format is not supported: {format}"); } - Ok(Expr::TryCast(TryCast::new( + Ok(Expr::try_cast(TryCast::new( Box::new(self.sql_expr_to_logical_expr( *expr, schema, @@ -263,21 +263,21 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { ))) } - SQLExpr::TypedString { data_type, value } => Ok(Expr::Cast(Cast::new( + SQLExpr::TypedString { data_type, value } => Ok(Expr::cast(Cast::new( Box::new(lit(value)), self.convert_data_type(&data_type)?, ))), - SQLExpr::IsNull(expr) => Ok(Expr::IsNull(Box::new( + SQLExpr::IsNull(expr) => Ok(Expr::_is_null(Box::new( self.sql_expr_to_logical_expr(*expr, schema, planner_context)?, ))), - SQLExpr::IsNotNull(expr) => Ok(Expr::IsNotNull(Box::new( + SQLExpr::IsNotNull(expr) => Ok(Expr::_is_not_null(Box::new( self.sql_expr_to_logical_expr(*expr, schema, planner_context)?, ))), SQLExpr::IsDistinctFrom(left, right) => { - Ok(Expr::BinaryExpr(BinaryExpr::new( + Ok(Expr::binary_expr(BinaryExpr::new( Box::new(self.sql_expr_to_logical_expr( *left, schema, @@ -293,7 +293,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } SQLExpr::IsNotDistinctFrom(left, right) => { - Ok(Expr::BinaryExpr(BinaryExpr::new( + Ok(Expr::binary_expr(BinaryExpr::new( Box::new(self.sql_expr_to_logical_expr( *left, schema, @@ -308,27 +308,27 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { ))) } - SQLExpr::IsTrue(expr) => Ok(Expr::IsTrue(Box::new( + SQLExpr::IsTrue(expr) => Ok(Expr::_is_true(Box::new( self.sql_expr_to_logical_expr(*expr, schema, planner_context)?, ))), - SQLExpr::IsFalse(expr) => Ok(Expr::IsFalse(Box::new( + SQLExpr::IsFalse(expr) => Ok(Expr::_is_false(Box::new( self.sql_expr_to_logical_expr(*expr, schema, planner_context)?, ))), - SQLExpr::IsNotTrue(expr) => Ok(Expr::IsNotTrue(Box::new( + SQLExpr::IsNotTrue(expr) => Ok(Expr::_is_not_true(Box::new( self.sql_expr_to_logical_expr(*expr, schema, planner_context)?, ))), - SQLExpr::IsNotFalse(expr) => Ok(Expr::IsNotFalse(Box::new( + SQLExpr::IsNotFalse(expr) => Ok(Expr::_is_not_false(Box::new( self.sql_expr_to_logical_expr(*expr, schema, planner_context)?, ))), - SQLExpr::IsUnknown(expr) => Ok(Expr::IsUnknown(Box::new( + SQLExpr::IsUnknown(expr) => Ok(Expr::_is_unknown(Box::new( self.sql_expr_to_logical_expr(*expr, schema, planner_context)?, ))), - SQLExpr::IsNotUnknown(expr) => Ok(Expr::IsNotUnknown(Box::new( + SQLExpr::IsNotUnknown(expr) => Ok(Expr::_is_not_unknown(Box::new( self.sql_expr_to_logical_expr(*expr, schema, planner_context)?, ))), @@ -341,7 +341,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { negated, low, high, - } => Ok(Expr::Between(Between::new( + } => Ok(Expr::_between(Between::new( Box::new(self.sql_expr_to_logical_expr( *expr, schema, @@ -509,7 +509,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { SQLExpr::AtTimeZone { timestamp, time_zone, - } => Ok(Expr::Cast(Cast::new( + } => Ok(Expr::cast(Cast::new( Box::new(self.sql_expr_to_logical_expr_internal( *timestamp, schema, @@ -565,11 +565,11 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } not_impl_err!("AnyOp not supported by ExprPlanner: {binary_expr:?}") } - SQLExpr::Wildcard => Ok(Expr::Wildcard(Wildcard { + SQLExpr::Wildcard => Ok(Expr::wildcard(Wildcard { qualifier: None, options: WildcardOptions::default(), })), - SQLExpr::QualifiedWildcard(object_name) => Ok(Expr::Wildcard(Wildcard { + SQLExpr::QualifiedWildcard(object_name) => Ok(Expr::wildcard(Wildcard { qualifier: Some(self.object_name_to_table_reference(object_name)?), options: WildcardOptions::default(), })), @@ -769,7 +769,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .map(|e| self.sql_expr_to_logical_expr(e, schema, planner_context)) .collect::>>()?; - Ok(Expr::InList(InList::new( + Ok(Expr::_in_list(InList::new( Box::new(self.sql_expr_to_logical_expr(expr, schema, planner_context)?), list_expr, negated, @@ -804,7 +804,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } else { None }; - Ok(Expr::Like(Like::new( + Ok(Expr::_like(Like::new( negated, Box::new(self.sql_expr_to_logical_expr(expr, schema, planner_context)?), Box::new(pattern), @@ -835,7 +835,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } else { None }; - Ok(Expr::SimilarTo(Like::new( + Ok(Expr::similar_to(Like::new( negated, Box::new(self.sql_expr_to_logical_expr(expr, schema, planner_context)?), Box::new(pattern), @@ -891,7 +891,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { internal_datafusion_err!("Unable to find expected '{fun_name}' function") })?; - Ok(Expr::ScalarFunction(ScalarFunction::new_udf(fun, args))) + Ok(Expr::scalar_function(ScalarFunction::new_udf(fun, args))) } fn sql_overlay_to_expr( @@ -946,7 +946,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { DataType::Timestamp(TimeUnit::Nanosecond, tz) if expr.get_type(schema)? == DataType::Int64 => { - Expr::Cast(Cast::new( + Expr::cast(Cast::new( Box::new(expr), DataType::Timestamp(TimeUnit::Second, tz.clone()), )) @@ -954,7 +954,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { _ => expr, }; - Ok(Expr::Cast(Cast::new(Box::new(expr), dt))) + Ok(Expr::cast(Cast::new(Box::new(expr), dt))) } fn sql_subscript_to_expr( diff --git a/datafusion/sql/src/expr/order_by.rs b/datafusion/sql/src/expr/order_by.rs index 00289806876fe..9553ee446dc4c 100644 --- a/datafusion/sql/src/expr/order_by.rs +++ b/datafusion/sql/src/expr/order_by.rs @@ -90,7 +90,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { ); } - Expr::Column(Column::from( + Expr::column(Column::from( input_schema.qualified_field(field_index - 1), )) } diff --git a/datafusion/sql/src/expr/subquery.rs b/datafusion/sql/src/expr/subquery.rs index ff161c6ed644e..45df1e27acab1 100644 --- a/datafusion/sql/src/expr/subquery.rs +++ b/datafusion/sql/src/expr/subquery.rs @@ -37,7 +37,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let sub_plan = self.query_to_plan(subquery, planner_context)?; let outer_ref_columns = sub_plan.all_out_ref_exprs(); planner_context.set_outer_query_schema(old_outer_query_schema); - Ok(Expr::Exists(Exists { + Ok(Expr::exists(Exists { subquery: Subquery { subquery: Arc::new(sub_plan), outer_ref_columns, @@ -60,7 +60,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let outer_ref_columns = sub_plan.all_out_ref_exprs(); planner_context.set_outer_query_schema(old_outer_query_schema); let expr = Box::new(self.sql_to_expr(expr, input_schema, planner_context)?); - Ok(Expr::InSubquery(InSubquery::new( + Ok(Expr::in_subquery(InSubquery::new( expr, Subquery { subquery: Arc::new(sub_plan), @@ -81,7 +81,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let sub_plan = self.query_to_plan(subquery, planner_context)?; let outer_ref_columns = sub_plan.all_out_ref_exprs(); planner_context.set_outer_query_schema(old_outer_query_schema); - Ok(Expr::ScalarSubquery(Subquery { + Ok(Expr::scalar_subquery(Subquery { subquery: Arc::new(sub_plan), outer_ref_columns, })) diff --git a/datafusion/sql/src/expr/substring.rs b/datafusion/sql/src/expr/substring.rs index f58ab5ff3612c..02a78e16d1bb6 100644 --- a/datafusion/sql/src/expr/substring.rs +++ b/datafusion/sql/src/expr/substring.rs @@ -51,7 +51,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { (None, Some(for_expr)) => { let arg = self.sql_expr_to_logical_expr(*expr, schema, planner_context)?; - let from_logic = Expr::Literal(ScalarValue::Int64(Some(1))); + let from_logic = Expr::literal(ScalarValue::Int64(Some(1))); let for_logic = self.sql_expr_to_logical_expr(*for_expr, schema, planner_context)?; vec![arg, from_logic, for_logic] diff --git a/datafusion/sql/src/expr/unary_op.rs b/datafusion/sql/src/expr/unary_op.rs index 06988eb03893b..5b9a3d3151951 100644 --- a/datafusion/sql/src/expr/unary_op.rs +++ b/datafusion/sql/src/expr/unary_op.rs @@ -32,7 +32,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { planner_context: &mut PlannerContext, ) -> Result { match op { - UnaryOperator::Not => Ok(Expr::Not(Box::new( + UnaryOperator::Not => Ok(Expr::_not(Box::new( self.sql_expr_to_logical_expr(expr, schema, planner_context)?, ))), UnaryOperator::Plus => { @@ -59,7 +59,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { self.sql_interval_to_expr(true, interval) } // Not a literal, apply negative operator on expression - _ => Ok(Expr::Negative(Box::new(self.sql_expr_to_logical_expr( + _ => Ok(Expr::negative(Box::new(self.sql_expr_to_logical_expr( expr, schema, planner_context, diff --git a/datafusion/sql/src/expr/value.rs b/datafusion/sql/src/expr/value.rs index 1cf090aa64aa6..aabba604d0ea1 100644 --- a/datafusion/sql/src/expr/value.rs +++ b/datafusion/sql/src/expr/value.rs @@ -41,7 +41,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { match value { Value::Number(n, _) => self.parse_sql_number(&n, false), Value::SingleQuotedString(s) | Value::DoubleQuotedString(s) => Ok(lit(s)), - Value::Null => Ok(Expr::Literal(ScalarValue::Null)), + Value::Null => Ok(Expr::literal(ScalarValue::Null)), Value::Boolean(n) => Ok(lit(n)), Value::Placeholder(param) => { Self::create_placeholder_expr(param, param_data_types) @@ -112,7 +112,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { Ok(index) => index - 1, Err(_) => { return if param_data_types.is_empty() { - Ok(Expr::Placeholder(Placeholder::new(param, None))) + Ok(Expr::placeholder(Placeholder::new(param, None))) } else { // when PREPARE Statement, param_data_types length is always 0 plan_err!("Invalid placeholder, not a number: {param}") @@ -127,7 +127,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { param, param_type ); - Ok(Expr::Placeholder(Placeholder::new( + Ok(Expr::placeholder(Placeholder::new( param, param_type.cloned(), ))) @@ -223,7 +223,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { fractional_seconds_precision: None, }, )?; - return Ok(Expr::BinaryExpr(BinaryExpr::new( + return Ok(Expr::binary_expr(BinaryExpr::new( Box::new(left_expr), df_op, Box::new(right_expr), @@ -349,7 +349,7 @@ fn parse_decimal_128(unsigned_number: &str, negative: bool) -> Result { )))); } - Ok(Expr::Literal(ScalarValue::Decimal128( + Ok(Expr::literal(ScalarValue::Decimal128( Some(if negative { -number } else { number }), precision as u8, scale as i8, diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index ccb2ccf7126f1..d5ae0878fd061 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -380,7 +380,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { find_column_exprs(exprs) .iter() .try_for_each(|col| match col { - Expr::Column(col) => match &col.relation { + Expr::Column(col, _) => match &col.relation { Some(r) => { schema.field_with_qualified_name(r, &col.name)?; Ok(()) diff --git a/datafusion/sql/src/query.rs b/datafusion/sql/src/query.rs index 740f9ad3b42c3..f88096852d621 100644 --- a/datafusion/sql/src/query.rs +++ b/datafusion/sql/src/query.rs @@ -116,11 +116,11 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { return Ok(plan); } - if let LogicalPlan::Distinct(Distinct::On(ref distinct_on)) = plan { + if let LogicalPlan::Distinct(Distinct::On(ref distinct_on), _) = plan { // In case of `DISTINCT ON` we must capture the sort expressions since during the plan // optimization we're effectively doing a `first_value` aggregation according to them. let distinct_on = distinct_on.clone().with_sort_expr(order_by)?; - Ok(LogicalPlan::Distinct(Distinct::On(distinct_on))) + Ok(LogicalPlan::distinct(Distinct::On(distinct_on))) } else { LogicalPlanBuilder::from(plan).sort(order_by)?.build() } @@ -133,7 +133,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { select_into: Option, ) -> Result { match select_into { - Some(into) => Ok(LogicalPlan::Ddl(DdlStatement::CreateMemoryTable( + Some(into) => Ok(LogicalPlan::ddl(DdlStatement::CreateMemoryTable( CreateMemoryTable { name: self.object_name_to_table_reference(into.name)?, constraints: Constraints::empty(), diff --git a/datafusion/sql/src/relation/mod.rs b/datafusion/sql/src/relation/mod.rs index 256cc58e71dc4..54ba85bb3f1e6 100644 --- a/datafusion/sql/src/relation/mod.rs +++ b/datafusion/sql/src/relation/mod.rs @@ -128,7 +128,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { planner_context, )?; Self::check_unnest_arg(&expr, &schema)?; - Ok(Expr::Unnest(Unnest::new(expr))) + Ok(Expr::unnest(Unnest::new(expr))) }) .collect::>>()?; if unnest_exprs.is_empty() { @@ -189,16 +189,16 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { planner_context.set_outer_from_schema(Some(old_from_schema)); match plan { - LogicalPlan::SubqueryAlias(SubqueryAlias { input, alias, .. }) => { + LogicalPlan::SubqueryAlias(SubqueryAlias { input, alias, .. }, _) => { subquery_alias( - LogicalPlan::Subquery(Subquery { + LogicalPlan::subquery(Subquery { subquery: input, outer_ref_columns, }), alias, ) } - plan => Ok(LogicalPlan::Subquery(Subquery { + plan => Ok(LogicalPlan::subquery(Subquery { subquery: Arc::new(plan), outer_ref_columns, })), @@ -215,17 +215,17 @@ fn optimize_subquery_sort(plan: LogicalPlan) -> Result> // 3. LIMIT => Handled by a `Sort`, so we need to search for it. let mut has_limit = false; let new_plan = plan.transform_down(|c| { - if let LogicalPlan::Limit(_) = c { + if let LogicalPlan::Limit(_, _) = c { has_limit = true; return Ok(Transformed::no(c)); } match c { - LogicalPlan::Sort(s) => { + LogicalPlan::Sort(s, _) => { if !has_limit { has_limit = false; return Ok(Transformed::yes(s.input.as_ref().clone())); } - Ok(Transformed::no(LogicalPlan::Sort(s))) + Ok(Transformed::no(LogicalPlan::sort(s))) } _ => Ok(Transformed::no(c)), } diff --git a/datafusion/sql/src/select.rs b/datafusion/sql/src/select.rs index 12fc013a2ab80..21d27a74dd952 100644 --- a/datafusion/sql/src/select.rs +++ b/datafusion/sql/src/select.rs @@ -72,7 +72,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // Process `from` clause let plan = self.plan_from_tables(select.from, planner_context)?; - let empty_from = matches!(plan, LogicalPlan::EmptyRelation(_)); + let empty_from = matches!(plan, LogicalPlan::EmptyRelation(_, _)); // Process `where` clause let base_plan = self.plan_selection(select.selection, plan, planner_context)?; @@ -178,9 +178,9 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { select_exprs .iter() .filter(|select_expr| match select_expr { - Expr::AggregateFunction(_) => false, - Expr::Alias(Alias { expr, name: _, .. }) => { - !matches!(**expr, Expr::AggregateFunction(_)) + Expr::AggregateFunction(_, _) => false, + Expr::Alias(Alias { expr, name: _, .. }, _) => { + !matches!(**expr, Expr::AggregateFunction(_, _)) } _ => true, }) @@ -364,7 +364,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { fn try_process_aggregate_unnest(&self, input: LogicalPlan) -> Result { match input { - LogicalPlan::Aggregate(agg) => { + LogicalPlan::Aggregate(agg, _) => { let agg_expr = agg.aggr_expr.clone(); let (new_input, new_group_by_exprs) = self.try_process_group_by_unnest(agg)?; @@ -372,12 +372,12 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .aggregate(new_group_by_exprs, agg_expr)? .build() } - LogicalPlan::Filter(mut filter) => { + LogicalPlan::Filter(mut filter, _) => { filter.input = Arc::new(self.try_process_aggregate_unnest(Arc::unwrap_or_clone( filter.input, ))?); - Ok(LogicalPlan::Filter(filter)) + Ok(LogicalPlan::filter(filter)) } _ => Ok(input), } @@ -440,8 +440,8 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let mut columns = HashSet::new(); for expr in &aggr_expr { expr.apply(|expr| { - if let Expr::Column(c) = expr { - columns.insert(Expr::Column(c.clone())); + if let Expr::Column(c, _) = expr { + columns.insert(Expr::column(c.clone())); } Ok(TreeNodeRecursion::Continue) }) @@ -519,7 +519,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { &[using_columns], )?; - Ok(LogicalPlan::Filter(Filter::try_new( + Ok(LogicalPlan::filter(Filter::try_new( filter_expr, Arc::new(plan), )?)) @@ -611,7 +611,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let name = self.ident_normalizer.normalize(alias); // avoiding adding an alias if the column name is the same. let expr = match &col { - Expr::Column(column) if column.name.eq(&name) => col, + Expr::Column(column, _) if column.name.eq(&name) => col, _ => col.alias(name), }; Ok(vec![expr]) @@ -746,7 +746,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let plan = LogicalPlanBuilder::from(input.clone()) .aggregate(group_by_exprs.to_vec(), aggr_exprs.to_vec())? .build()?; - let group_by_exprs = if let LogicalPlan::Aggregate(agg) = &plan { + let group_by_exprs = if let LogicalPlan::Aggregate(agg, _) = &plan { &agg.group_expr } else { unreachable!(); @@ -762,13 +762,13 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let mut aggr_projection_exprs = vec![]; for expr in group_by_exprs { match expr { - Expr::GroupingSet(GroupingSet::Rollup(exprs)) => { + Expr::GroupingSet(GroupingSet::Rollup(exprs), _) => { aggr_projection_exprs.extend_from_slice(exprs) } - Expr::GroupingSet(GroupingSet::Cube(exprs)) => { + Expr::GroupingSet(GroupingSet::Cube(exprs), _) => { aggr_projection_exprs.extend_from_slice(exprs) } - Expr::GroupingSet(GroupingSet::GroupingSets(lists_of_exprs)) => { + Expr::GroupingSet(GroupingSet::GroupingSets(lists_of_exprs), _) => { for exprs in lists_of_exprs { aggr_projection_exprs.extend_from_slice(exprs) } diff --git a/datafusion/sql/src/statement.rs b/datafusion/sql/src/statement.rs index 31b836f32b242..1b9fbaed7db58 100644 --- a/datafusion/sql/src/statement.rs +++ b/datafusion/sql/src/statement.rs @@ -443,7 +443,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { plan.schema(), )?; - Ok(LogicalPlan::Ddl(DdlStatement::CreateMemoryTable( + Ok(LogicalPlan::ddl(DdlStatement::CreateMemoryTable( CreateMemoryTable { name: self.object_name_to_table_reference(name)?, constraints, @@ -461,12 +461,12 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { produce_one_row: false, schema, }; - let plan = LogicalPlan::EmptyRelation(plan); + let plan = LogicalPlan::empty_relation(plan); let constraints = Self::new_constraint_from_table_constraints( &all_constraints, plan.schema(), )?; - Ok(LogicalPlan::Ddl(DdlStatement::CreateMemoryTable( + Ok(LogicalPlan::ddl(DdlStatement::CreateMemoryTable( CreateMemoryTable { name: self.object_name_to_table_reference(name)?, constraints, @@ -530,7 +530,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let mut plan = self.query_to_plan(*query, &mut PlannerContext::new())?; plan = self.apply_expr_alias(plan, columns)?; - Ok(LogicalPlan::Ddl(DdlStatement::CreateView(CreateView { + Ok(LogicalPlan::ddl(DdlStatement::CreateView(CreateView { name: self.object_name_to_table_reference(name)?, input: Arc::new(plan), or_replace, @@ -547,7 +547,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { Statement::CreateSchema { schema_name, if_not_exists, - } => Ok(LogicalPlan::Ddl(DdlStatement::CreateCatalogSchema( + } => Ok(LogicalPlan::ddl(DdlStatement::CreateCatalogSchema( CreateCatalogSchema { schema_name: get_schema_name(&schema_name), if_not_exists, @@ -558,7 +558,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { db_name, if_not_exists, .. - } => Ok(LogicalPlan::Ddl(DdlStatement::CreateCatalog( + } => Ok(LogicalPlan::ddl(DdlStatement::CreateCatalog( CreateCatalog { catalog_name: object_name_to_string(&db_name), if_not_exists, @@ -587,14 +587,14 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { match object_type { ObjectType::Table => { - Ok(LogicalPlan::Ddl(DdlStatement::DropTable(DropTable { + Ok(LogicalPlan::ddl(DdlStatement::DropTable(DropTable { name, if_exists, schema: DFSchemaRef::new(DFSchema::empty()), }))) } ObjectType::View => { - Ok(LogicalPlan::Ddl(DdlStatement::DropView(DropView { + Ok(LogicalPlan::ddl(DdlStatement::DropView(DropView { name, if_exists, schema: DFSchemaRef::new(DFSchema::empty()), @@ -608,7 +608,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { Err(ParserError("Invalid schema specifier (has 3 parts)".to_string())) } }?; - Ok(LogicalPlan::Ddl(DdlStatement::DropCatalogSchema(DropCatalogSchema { + Ok(LogicalPlan::ddl(DdlStatement::DropCatalogSchema(DropCatalogSchema { name, if_exists, cascade, @@ -640,7 +640,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { *statement, &mut planner_context, )?; - Ok(LogicalPlan::Statement(PlanStatement::Prepare(Prepare { + Ok(LogicalPlan::statement(PlanStatement::Prepare(Prepare { name: ident_to_string(&name), data_types, input: Arc::new(plan), @@ -667,7 +667,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .map(|expr| self.sql_to_expr(expr, &empty_schema, planner_context)) .collect::>>()?; - Ok(LogicalPlan::Statement(PlanStatement::Execute(Execute { + Ok(LogicalPlan::statement(PlanStatement::Execute(Execute { name: object_name_to_string(&name), parameters, }))) @@ -676,7 +676,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { name, // Similar to PostgreSQL, the PREPARE keyword is ignored prepare: _, - } => Ok(LogicalPlan::Statement(PlanStatement::Deallocate( + } => Ok(LogicalPlan::statement(PlanStatement::Deallocate( Deallocate { name: ident_to_string(&name), }, @@ -860,14 +860,14 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { access_mode, isolation_level, }); - Ok(LogicalPlan::Statement(statement)) + Ok(LogicalPlan::statement(statement)) } Statement::Commit { chain } => { let statement = PlanStatement::TransactionEnd(TransactionEnd { conclusion: TransactionConclusion::Commit, chain, }); - Ok(LogicalPlan::Statement(statement)) + Ok(LogicalPlan::statement(statement)) } Statement::Rollback { chain, savepoint } => { if savepoint.is_some() { @@ -877,7 +877,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { conclusion: TransactionConclusion::Rollback, chain, }); - Ok(LogicalPlan::Statement(statement)) + Ok(LogicalPlan::statement(statement)) } Statement::CreateFunction { or_replace, @@ -971,7 +971,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { schema: DFSchemaRef::new(DFSchema::empty()), }); - Ok(LogicalPlan::Ddl(statement)) + Ok(LogicalPlan::ddl(statement)) } Statement::DropFunction { if_exists, @@ -992,7 +992,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { name, schema: DFSchemaRef::new(DFSchema::empty()), }); - Ok(LogicalPlan::Ddl(statement)) + Ok(LogicalPlan::ddl(statement)) } else { exec_err!("Function name not provided") } @@ -1021,7 +1021,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { false, None, )?; - Ok(LogicalPlan::Ddl(DdlStatement::CreateIndex( + Ok(LogicalPlan::ddl(DdlStatement::CreateIndex( PlanCreateIndex { name, table, @@ -1097,7 +1097,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let output_schema = DFSchema::try_from(LogicalPlan::describe_schema()).unwrap(); - Ok(LogicalPlan::DescribeTable(DescribeTable { + Ok(LogicalPlan::describe_table(DescribeTable { schema, output_schema: Arc::new(output_schema), })) @@ -1166,7 +1166,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .map(|f| f.name().to_owned()) .collect(); - Ok(LogicalPlan::Copy(CopyTo { + Ok(LogicalPlan::copy(CopyTo { input: Arc::new(input), output_url: statement.target, file_type, @@ -1288,7 +1288,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let name = self.object_name_to_table_reference(name)?; let constraints = Self::new_constraint_from_table_constraints(&all_constraints, &df_schema)?; - Ok(LogicalPlan::Ddl(DdlStatement::CreateExternalTable( + Ok(LogicalPlan::ddl(DdlStatement::CreateExternalTable( PlanCreateExternalTable { schema: df_schema, name, @@ -1416,7 +1416,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { statement: DFStatement, ) -> Result { let plan = self.statement_to_plan(statement)?; - if matches!(plan, LogicalPlan::Explain(_)) { + if matches!(plan, LogicalPlan::Explain(_, _)) { return plan_err!("Nested EXPLAINs are not supported"); } let plan = Arc::new(plan); @@ -1424,7 +1424,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let schema = schema.to_dfschema_ref()?; if analyze { - Ok(LogicalPlan::Analyze(Analyze { + Ok(LogicalPlan::analyze(Analyze { verbose, input: plan, schema, @@ -1432,7 +1432,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } else { let stringified_plans = vec![plan.to_stringified(PlanType::InitialLogicalPlan)]; - Ok(LogicalPlan::Explain(Explain { + Ok(LogicalPlan::explain(Explain { verbose, plan, stringified_plans, @@ -1552,7 +1552,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { value: value_string, }); - Ok(LogicalPlan::Statement(statement)) + Ok(LogicalPlan::statement(statement)) } fn delete_to_plan( @@ -1586,11 +1586,11 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { &[&[&schema]], &[using_columns], )?; - LogicalPlan::Filter(Filter::try_new(filter_expr, Arc::new(scan))?) + LogicalPlan::filter(Filter::try_new(filter_expr, Arc::new(scan))?) } }; - let plan = LogicalPlan::Dml(DmlStatement::new( + let plan = LogicalPlan::dml(DmlStatement::new( table_ref, schema.into(), WriteOp::Delete, @@ -1660,7 +1660,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { &[&[scan.schema()]], &[using_columns], )?; - LogicalPlan::Filter(Filter::try_new(filter_expr, Arc::new(scan))?) + LogicalPlan::filter(Filter::try_new(filter_expr, Arc::new(scan))?) } }; @@ -1676,7 +1676,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { &mut planner_context, )?; // Update placeholder's datatype to the type of the target column - if let Expr::Placeholder(placeholder) = &mut expr { + if let Expr::Placeholder(placeholder, _) = &mut expr { placeholder.data_type = placeholder .data_type .take() @@ -1688,12 +1688,12 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { None => { // If the target table has an alias, use it to qualify the column name if let Some(alias) = &table_alias { - Expr::Column(Column::new( + Expr::column(Column::new( Some(self.ident_normalizer.normalize(alias.name.clone())), field.name(), )) } else { - Expr::Column(Column::from((qualifier, field))) + Expr::column(Column::from((qualifier, field))) } } }; @@ -1703,7 +1703,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let source = project(source, exprs)?; - let plan = LogicalPlan::Dml(DmlStatement::new( + let plan = LogicalPlan::dml(DmlStatement::new( table_name, table_schema, WriteOp::Update, @@ -1803,7 +1803,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let target_field = table_schema.field(i); let expr = match value_index { Some(v) => { - Expr::Column(Column::from(source.schema().qualified_field(v))) + Expr::column(Column::from(source.schema().qualified_field(v))) .cast_to(target_field.data_type(), source.schema())? } // The value is not specified. Fill in the default value for the column. @@ -1812,7 +1812,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .cloned() .unwrap_or_else(|| { // If there is no default for the column, then the default is NULL - Expr::Literal(ScalarValue::Null) + Expr::literal(ScalarValue::Null) }) .cast_to(target_field.data_type(), &DFSchema::empty())?, }; @@ -1828,7 +1828,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { (true, true) => plan_err!("Conflicting insert operations: `overwrite` and `replace_into` cannot both be true")?, }; - let plan = LogicalPlan::Dml(DmlStatement::new( + let plan = LogicalPlan::dml(DmlStatement::new( table_name, Arc::new(table_schema), WriteOp::Insert(insert_op), diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index 587e323e162b6..6f90edbb3722a 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -93,11 +93,14 @@ impl Unparser<'_> { fn expr_to_sql_inner(&self, expr: &Expr) -> Result { match expr { - Expr::InList(InList { - expr, - list, - negated, - }) => { + Expr::InList( + InList { + expr, + list, + negated, + }, + _, + ) => { let list_expr = list .iter() .map(|e| self.expr_to_sql_inner(e)) @@ -108,7 +111,7 @@ impl Unparser<'_> { negated: *negated, }) } - Expr::ScalarFunction(ScalarFunction { func, args }) => { + Expr::ScalarFunction(ScalarFunction { func, args }, _) => { let func_name = func.name(); if let Some(expr) = self @@ -120,12 +123,15 @@ impl Unparser<'_> { self.scalar_function_to_sql(func_name, args) } - Expr::Between(Between { - expr, - negated, - low, - high, - }) => { + Expr::Between( + Between { + expr, + negated, + low, + high, + }, + _, + ) => { let sql_parser_expr = self.expr_to_sql_inner(expr)?; let sql_low = self.expr_to_sql_inner(low)?; let sql_high = self.expr_to_sql_inner(high)?; @@ -136,19 +142,22 @@ impl Unparser<'_> { sql_high, )))) } - Expr::Column(col) => self.col_to_sql(col), - Expr::BinaryExpr(BinaryExpr { left, op, right }) => { + Expr::Column(col, _) => self.col_to_sql(col), + Expr::BinaryExpr(BinaryExpr { left, op, right }, _) => { let l = self.expr_to_sql_inner(left.as_ref())?; let r = self.expr_to_sql_inner(right.as_ref())?; let op = self.op_to_sql(op)?; Ok(ast::Expr::Nested(Box::new(self.binary_op_to_sql(l, r, op)))) } - Expr::Case(Case { - expr, - when_then_expr, - else_expr, - }) => { + Expr::Case( + Case { + expr, + when_then_expr, + else_expr, + }, + _, + ) => { let conditions = when_then_expr .iter() .map(|(w, _)| self.expr_to_sql_inner(w)) @@ -179,19 +188,22 @@ impl Unparser<'_> { else_result, }) } - Expr::Cast(Cast { expr, data_type }) => { + Expr::Cast(Cast { expr, data_type }, _) => { Ok(self.cast_to_sql(expr, data_type)?) } - Expr::Literal(value) => Ok(self.scalar_to_sql(value)?), - Expr::Alias(Alias { expr, name: _, .. }) => self.expr_to_sql_inner(expr), - Expr::WindowFunction(WindowFunction { - fun, - args, - partition_by, - order_by, - window_frame, - null_treatment: _, - }) => { + Expr::Literal(value, _) => Ok(self.scalar_to_sql(value)?), + Expr::Alias(Alias { expr, name: _, .. }, _) => self.expr_to_sql_inner(expr), + Expr::WindowFunction( + WindowFunction { + fun, + args, + partition_by, + order_by, + window_frame, + null_treatment: _, + }, + _, + ) => { let func_name = fun.name(); let args = self.function_args_to_sql(args)?; @@ -246,27 +258,33 @@ impl Unparser<'_> { parameters: ast::FunctionArguments::None, })) } - Expr::SimilarTo(Like { - negated, - expr, - pattern, - escape_char, - case_insensitive: _, - }) - | Expr::Like(Like { - negated, - expr, - pattern, - escape_char, - case_insensitive: _, - }) => Ok(ast::Expr::Like { + Expr::SimilarTo( + Like { + negated, + expr, + pattern, + escape_char, + case_insensitive: _, + }, + _, + ) + | Expr::Like( + Like { + negated, + expr, + pattern, + escape_char, + case_insensitive: _, + }, + _, + ) => Ok(ast::Expr::Like { negated: *negated, expr: Box::new(self.expr_to_sql_inner(expr)?), pattern: Box::new(self.expr_to_sql_inner(pattern)?), escape_char: escape_char.map(|c| c.to_string()), any: false, }), - Expr::AggregateFunction(agg) => { + Expr::AggregateFunction(agg, _) => { let func_name = agg.func.name(); let args = self.function_args_to_sql(&agg.args)?; @@ -293,7 +311,7 @@ impl Unparser<'_> { parameters: ast::FunctionArguments::None, })) } - Expr::ScalarSubquery(subq) => { + Expr::ScalarSubquery(subq, _) => { let sub_statement = self.plan_to_sql(subq.subquery.as_ref())?; let sub_query = if let ast::Statement::Query(inner_query) = sub_statement { @@ -305,7 +323,7 @@ impl Unparser<'_> { }; Ok(ast::Expr::Subquery(sub_query)) } - Expr::InSubquery(insubq) => { + Expr::InSubquery(insubq, _) => { let inexpr = Box::new(self.expr_to_sql_inner(insubq.expr.as_ref())?); let sub_statement = self.plan_to_sql(insubq.subquery.subquery.as_ref())?; @@ -323,7 +341,7 @@ impl Unparser<'_> { negated: insubq.negated, }) } - Expr::Exists(Exists { subquery, negated }) => { + Expr::Exists(Exists { subquery, negated }, _) => { let sub_statement = self.plan_to_sql(subquery.subquery.as_ref())?; let sub_query = if let ast::Statement::Query(inner_query) = sub_statement { @@ -338,45 +356,45 @@ impl Unparser<'_> { negated: *negated, }) } - Expr::IsNull(expr) => { + Expr::IsNull(expr, _) => { Ok(ast::Expr::IsNull(Box::new(self.expr_to_sql_inner(expr)?))) } - Expr::IsNotNull(expr) => Ok(ast::Expr::IsNotNull(Box::new( + Expr::IsNotNull(expr, _) => Ok(ast::Expr::IsNotNull(Box::new( self.expr_to_sql_inner(expr)?, ))), - Expr::IsTrue(expr) => { + Expr::IsTrue(expr, _) => { Ok(ast::Expr::IsTrue(Box::new(self.expr_to_sql_inner(expr)?))) } - Expr::IsNotTrue(expr) => Ok(ast::Expr::IsNotTrue(Box::new( + Expr::IsNotTrue(expr, _) => Ok(ast::Expr::IsNotTrue(Box::new( self.expr_to_sql_inner(expr)?, ))), - Expr::IsFalse(expr) => { + Expr::IsFalse(expr, _) => { Ok(ast::Expr::IsFalse(Box::new(self.expr_to_sql_inner(expr)?))) } - Expr::IsNotFalse(expr) => Ok(ast::Expr::IsNotFalse(Box::new( + Expr::IsNotFalse(expr, _) => Ok(ast::Expr::IsNotFalse(Box::new( self.expr_to_sql_inner(expr)?, ))), - Expr::IsUnknown(expr) => Ok(ast::Expr::IsUnknown(Box::new( + Expr::IsUnknown(expr, _) => Ok(ast::Expr::IsUnknown(Box::new( self.expr_to_sql_inner(expr)?, ))), - Expr::IsNotUnknown(expr) => Ok(ast::Expr::IsNotUnknown(Box::new( + Expr::IsNotUnknown(expr, _) => Ok(ast::Expr::IsNotUnknown(Box::new( self.expr_to_sql_inner(expr)?, ))), - Expr::Not(expr) => { + Expr::Not(expr, _) => { let sql_parser_expr = self.expr_to_sql_inner(expr)?; Ok(AstExpr::UnaryOp { op: UnaryOperator::Not, expr: Box::new(sql_parser_expr), }) } - Expr::Negative(expr) => { + Expr::Negative(expr, _) => { let sql_parser_expr = self.expr_to_sql_inner(expr)?; Ok(AstExpr::UnaryOp { op: UnaryOperator::Minus, expr: Box::new(sql_parser_expr), }) } - Expr::ScalarVariable(_, ids) => { + Expr::ScalarVariable(_, ids, _) => { if ids.is_empty() { return internal_err!("Not a valid ScalarVariable"); } @@ -393,7 +411,7 @@ impl Unparser<'_> { ) }) } - Expr::TryCast(TryCast { expr, data_type }) => { + Expr::TryCast(TryCast { expr, data_type }, _) => { let inner_expr = self.expr_to_sql_inner(expr)?; Ok(ast::Expr::Cast { kind: ast::CastKind::TryCast, @@ -403,7 +421,7 @@ impl Unparser<'_> { }) } // TODO: unparsing wildcard addition options - Expr::Wildcard(Wildcard { qualifier, .. }) => { + Expr::Wildcard(Wildcard { qualifier, .. }, _) => { if let Some(qualifier) = qualifier { let idents: Vec = qualifier.to_vec().into_iter().map(Ident::new).collect(); @@ -412,7 +430,7 @@ impl Unparser<'_> { Ok(ast::Expr::Wildcard) } } - Expr::GroupingSet(grouping_set) => match grouping_set { + Expr::GroupingSet(grouping_set, _) => match grouping_set { GroupingSet::GroupingSets(grouping_sets) => { let expr_ast_sets = grouping_sets .iter() @@ -446,11 +464,11 @@ impl Unparser<'_> { Ok(ast::Expr::Rollup(expr_ast_sets)) } }, - Expr::Placeholder(p) => { + Expr::Placeholder(p, _) => { Ok(ast::Expr::Value(ast::Value::Placeholder(p.id.to_string()))) } - Expr::OuterReferenceColumn(_, col) => self.col_to_sql(col), - Expr::Unnest(unnest) => self.unnest_to_sql(unnest), + Expr::OuterReferenceColumn(_, col, _) => self.col_to_sql(col), + Expr::Unnest(unnest, _) => self.unnest_to_sql(unnest), } } @@ -526,7 +544,7 @@ impl Unparser<'_> { .chunks_exact(2) .map(|chunk| { let key = match &chunk[0] { - Expr::Literal(ScalarValue::Utf8(Some(s))) => self.new_ident_quoted_if_needs(s.to_string()), + Expr::Literal(ScalarValue::Utf8(Some(s)), _) => self.new_ident_quoted_if_needs(s.to_string()), _ => return internal_err!("named_struct expects even arguments to be strings, but received: {:?}", &chunk[0]) }; @@ -546,7 +564,7 @@ impl Unparser<'_> { } let mut id = match &args[0] { - Expr::Column(col) => match self.col_to_sql(col)? { + Expr::Column(col, _) => match self.col_to_sql(col)? { ast::Expr::Identifier(ident) => vec![ident], ast::Expr::CompoundIdentifier(idents) => idents, other => return internal_err!("expected col_to_sql to return an Identifier or CompoundIdentifier, but received: {:?}", other), @@ -555,7 +573,7 @@ impl Unparser<'_> { }; let field = match &args[1] { - Expr::Literal(lit) => self.new_ident_quoted_if_needs(lit.to_string()), + Expr::Literal(lit, _) => self.new_ident_quoted_if_needs(lit.to_string()), _ => { return internal_err!( "get_field expects second argument to be a string, but received: {:?}", @@ -689,10 +707,13 @@ impl Unparser<'_> { .map(|e| { if matches!( e, - Expr::Wildcard(Wildcard { - qualifier: None, - .. - }) + Expr::Wildcard( + Wildcard { + qualifier: None, + .. + }, + _ + ) ) { Ok(ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Wildcard)) } else { @@ -1666,7 +1687,7 @@ mod tests { fn expr_to_sql_ok() -> Result<()> { let dummy_schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); let dummy_logical_plan = table_scan(Some("t"), &dummy_schema, None)? - .project(vec![Expr::Wildcard(Wildcard { + .project(vec![Expr::wildcard(Wildcard { qualifier: None, options: WildcardOptions::default(), })])? @@ -1676,7 +1697,7 @@ mod tests { let tests: Vec<(Expr, &str)> = vec![ ((col("a") + col("b")).gt(lit(4)), r#"((a + b) > 4)"#), ( - Expr::Column(Column { + Expr::column(Column { relation: Some(TableReference::partial("a", "b")), name: "c".to_string(), }) @@ -1699,14 +1720,14 @@ mod tests { r#"CASE WHEN a IS NOT NULL THEN true ELSE false END"#, ), ( - Expr::Cast(Cast { + Expr::cast(Cast { expr: Box::new(col("a")), data_type: DataType::Date64, }), r#"CAST(a AS DATETIME)"#, ), ( - Expr::Cast(Cast { + Expr::cast(Cast { expr: Box::new(col("a")), data_type: DataType::Timestamp( TimeUnit::Nanosecond, @@ -1716,14 +1737,14 @@ mod tests { r#"CAST(a AS TIMESTAMP WITH TIME ZONE)"#, ), ( - Expr::Cast(Cast { + Expr::cast(Cast { expr: Box::new(col("a")), data_type: DataType::Timestamp(TimeUnit::Millisecond, None), }), r#"CAST(a AS TIMESTAMP)"#, ), ( - Expr::Cast(Cast { + Expr::cast(Cast { expr: Box::new(col("a")), data_type: DataType::UInt32, }), @@ -1754,7 +1775,7 @@ mod tests { r#"dummy_udf(a, b) IS NOT NULL"#, ), ( - Expr::Like(Like { + Expr::_like(Like { negated: true, expr: Box::new(col("a")), pattern: Box::new(lit("foo")), @@ -1764,7 +1785,7 @@ mod tests { r#"a NOT LIKE 'foo' ESCAPE 'o'"#, ), ( - Expr::SimilarTo(Like { + Expr::similar_to(Like { negated: false, expr: Box::new(col("a")), pattern: Box::new(lit("foo")), @@ -1774,93 +1795,93 @@ mod tests { r#"a LIKE 'foo' ESCAPE 'o'"#, ), ( - Expr::Literal(ScalarValue::Date64(Some(0))), + Expr::literal(ScalarValue::Date64(Some(0))), r#"CAST('1970-01-01 00:00:00' AS DATETIME)"#, ), ( - Expr::Literal(ScalarValue::Date64(Some(10000))), + Expr::literal(ScalarValue::Date64(Some(10000))), r#"CAST('1970-01-01 00:00:10' AS DATETIME)"#, ), ( - Expr::Literal(ScalarValue::Date64(Some(-10000))), + Expr::literal(ScalarValue::Date64(Some(-10000))), r#"CAST('1969-12-31 23:59:50' AS DATETIME)"#, ), ( - Expr::Literal(ScalarValue::Date32(Some(0))), + Expr::literal(ScalarValue::Date32(Some(0))), r#"CAST('1970-01-01' AS DATE)"#, ), ( - Expr::Literal(ScalarValue::Date32(Some(10))), + Expr::literal(ScalarValue::Date32(Some(10))), r#"CAST('1970-01-11' AS DATE)"#, ), ( - Expr::Literal(ScalarValue::Date32(Some(-1))), + Expr::literal(ScalarValue::Date32(Some(-1))), r#"CAST('1969-12-31' AS DATE)"#, ), ( - Expr::Literal(ScalarValue::TimestampSecond(Some(10001), None)), + Expr::literal(ScalarValue::TimestampSecond(Some(10001), None)), r#"CAST('1970-01-01 02:46:41' AS TIMESTAMP)"#, ), ( - Expr::Literal(ScalarValue::TimestampSecond( + Expr::literal(ScalarValue::TimestampSecond( Some(10001), Some("+08:00".into()), )), r#"CAST('1970-01-01 10:46:41 +08:00' AS TIMESTAMP)"#, ), ( - Expr::Literal(ScalarValue::TimestampMillisecond(Some(10001), None)), + Expr::literal(ScalarValue::TimestampMillisecond(Some(10001), None)), r#"CAST('1970-01-01 00:00:10.001' AS TIMESTAMP)"#, ), ( - Expr::Literal(ScalarValue::TimestampMillisecond( + Expr::literal(ScalarValue::TimestampMillisecond( Some(10001), Some("+08:00".into()), )), r#"CAST('1970-01-01 08:00:10.001 +08:00' AS TIMESTAMP)"#, ), ( - Expr::Literal(ScalarValue::TimestampMicrosecond(Some(10001), None)), + Expr::literal(ScalarValue::TimestampMicrosecond(Some(10001), None)), r#"CAST('1970-01-01 00:00:00.010001' AS TIMESTAMP)"#, ), ( - Expr::Literal(ScalarValue::TimestampMicrosecond( + Expr::literal(ScalarValue::TimestampMicrosecond( Some(10001), Some("+08:00".into()), )), r#"CAST('1970-01-01 08:00:00.010001 +08:00' AS TIMESTAMP)"#, ), ( - Expr::Literal(ScalarValue::TimestampNanosecond(Some(10001), None)), + Expr::literal(ScalarValue::TimestampNanosecond(Some(10001), None)), r#"CAST('1970-01-01 00:00:00.000010001' AS TIMESTAMP)"#, ), ( - Expr::Literal(ScalarValue::TimestampNanosecond( + Expr::literal(ScalarValue::TimestampNanosecond( Some(10001), Some("+08:00".into()), )), r#"CAST('1970-01-01 08:00:00.000010001 +08:00' AS TIMESTAMP)"#, ), ( - Expr::Literal(ScalarValue::Time32Second(Some(10001))), + Expr::literal(ScalarValue::Time32Second(Some(10001))), r#"CAST('02:46:41' AS TIME)"#, ), ( - Expr::Literal(ScalarValue::Time32Millisecond(Some(10001))), + Expr::literal(ScalarValue::Time32Millisecond(Some(10001))), r#"CAST('00:00:10.001' AS TIME)"#, ), ( - Expr::Literal(ScalarValue::Time64Microsecond(Some(10001))), + Expr::literal(ScalarValue::Time64Microsecond(Some(10001))), r#"CAST('00:00:00.010001' AS TIME)"#, ), ( - Expr::Literal(ScalarValue::Time64Nanosecond(Some(10001))), + Expr::literal(ScalarValue::Time64Nanosecond(Some(10001))), r#"CAST('00:00:00.000010001' AS TIME)"#, ), (sum(col("a")), r#"sum(a)"#), ( count_udaf() - .call(vec![Expr::Wildcard(Wildcard { + .call(vec![Expr::wildcard(Wildcard { qualifier: None, options: WildcardOptions::default(), })]) @@ -1871,7 +1892,7 @@ mod tests { ), ( count_udaf() - .call(vec![Expr::Wildcard(Wildcard { + .call(vec![Expr::wildcard(Wildcard { qualifier: None, options: WildcardOptions::default(), })]) @@ -1881,7 +1902,7 @@ mod tests { "count(*) FILTER (WHERE true)", ), ( - Expr::WindowFunction(WindowFunction { + Expr::window_function(WindowFunction { fun: WindowFunctionDefinition::WindowUDF(row_number_udwf()), args: vec![col("col")], partition_by: vec![], @@ -1892,7 +1913,7 @@ mod tests { r#"row_number(col) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)"#, ), ( - Expr::WindowFunction(WindowFunction { + Expr::window_function(WindowFunction { fun: WindowFunctionDefinition::AggregateUDF(count_udaf()), args: vec![wildcard()], partition_by: vec![], @@ -1941,7 +1962,7 @@ mod tests { Expr::between(col("a"), lit(1), lit(7)), r#"(a BETWEEN 1 AND 7)"#, ), - (Expr::Negative(Box::new(col("a"))), r#"-a"#), + (Expr::negative(Box::new(col("a"))), r#"-a"#), ( exists(Arc::new(dummy_logical_plan.clone())), r#"EXISTS (SELECT * FROM t WHERE (t.a = 1))"#, @@ -1959,11 +1980,11 @@ mod tests { r#"TRY_CAST(a AS INTEGER UNSIGNED)"#, ), ( - Expr::ScalarVariable(Int8, vec![String::from("@a")]), + Expr::scalar_variable(Int8, vec![String::from("@a")]), r#"@a"#, ), ( - Expr::ScalarVariable( + Expr::scalar_variable( Int8, vec![String::from("@root"), String::from("foo")], ), @@ -1989,7 +2010,7 @@ mod tests { (col("need quoted").eq(lit(1)), r#"("need quoted" = 1)"#), // See test_interval_scalar_to_expr for interval literals ( - (col("a") + col("b")).gt(Expr::Literal(ScalarValue::Decimal128( + (col("a") + col("b")).gt(Expr::literal(ScalarValue::Decimal128( Some(100123), 28, 3, @@ -1997,7 +2018,7 @@ mod tests { r#"((a + b) > 100.123)"#, ), ( - (col("a") + col("b")).gt(Expr::Literal(ScalarValue::Decimal256( + (col("a") + col("b")).gt(Expr::literal(ScalarValue::Decimal256( Some(100123.into()), 28, 3, @@ -2005,15 +2026,15 @@ mod tests { r#"((a + b) > 100.123)"#, ), ( - Expr::Cast(Cast { + Expr::cast(Cast { expr: Box::new(col("a")), data_type: DataType::Decimal128(10, -2), }), r#"CAST(a AS DECIMAL(12,0))"#, ), ( - Expr::Unnest(Unnest { - expr: Box::new(Expr::Column(Column { + Expr::unnest(Unnest { + expr: Box::new(Expr::column(Column { relation: Some(TableReference::partial("schema", "table")), name: "array_col".to_string(), })), @@ -2092,7 +2113,7 @@ mod tests { .build(); let unparser = Unparser::new(&dialect); - let expr = Expr::Cast(Cast { + let expr = Expr::cast(Cast { expr: Box::new(col("a")), data_type: DataType::Date64, }); @@ -2117,7 +2138,7 @@ mod tests { .build(); let unparser = Unparser::new(&dialect); - let expr = Expr::Cast(Cast { + let expr = Expr::cast(Cast { expr: Box::new(col("a")), data_type: DataType::Float64, }); @@ -2321,11 +2342,11 @@ mod tests { #[test] fn test_float_scalar_to_expr() { let tests = [ - (Expr::Literal(ScalarValue::Float64(Some(3f64))), "3.0"), - (Expr::Literal(ScalarValue::Float64(Some(3.1f64))), "3.1"), - (Expr::Literal(ScalarValue::Float32(Some(-2f32))), "-2.0"), + (Expr::literal(ScalarValue::Float64(Some(3f64))), "3.0"), + (Expr::literal(ScalarValue::Float64(Some(3.1f64))), "3.1"), + (Expr::literal(ScalarValue::Float32(Some(-2f32))), "-2.0"), ( - Expr::Literal(ScalarValue::Float32(Some(-2.989f32))), + Expr::literal(ScalarValue::Float32(Some(-2.989f32))), "-2.989", ), ]; @@ -2344,8 +2365,8 @@ mod tests { fn test_cast_value_to_binary_expr() { let tests = [ ( - Expr::Cast(Cast { - expr: Box::new(Expr::Literal(ScalarValue::Utf8(Some( + Expr::cast(Cast { + expr: Box::new(Expr::literal(ScalarValue::Utf8(Some( "blah".to_string(), )))), data_type: DataType::Binary, @@ -2353,8 +2374,8 @@ mod tests { "'blah'", ), ( - Expr::Cast(Cast { - expr: Box::new(Expr::Literal(ScalarValue::Utf8(Some( + Expr::cast(Cast { + expr: Box::new(Expr::literal(ScalarValue::Utf8(Some( "blah".to_string(), )))), data_type: DataType::BinaryView, @@ -2389,7 +2410,7 @@ mod tests { ] { let unparser = Unparser::new(dialect); - let expr = Expr::Cast(Cast { + let expr = Expr::cast(Cast { expr: Box::new(col("a")), data_type, }); @@ -2448,7 +2469,7 @@ mod tests { let expr = ScalarUDF::new_from_impl( datafusion_functions::datetime::date_part::DatePartFunc::new(), ) - .call(vec![Expr::Literal(ScalarValue::new_utf8(unit)), col("x")]); + .call(vec![Expr::literal(ScalarValue::new_utf8(unit)), col("x")]); let ast = unparser.expr_to_sql(&expr)?; let actual = format!("{}", ast); @@ -2472,7 +2493,7 @@ mod tests { [(default_dialect, "BIGINT"), (mysql_dialect, "SIGNED")] { let unparser = Unparser::new(&dialect); - let expr = Expr::Cast(Cast { + let expr = Expr::cast(Cast { expr: Box::new(col("a")), data_type: DataType::Int64, }); @@ -2500,7 +2521,7 @@ mod tests { [(default_dialect, "INTEGER"), (mysql_dialect, "SIGNED")] { let unparser = Unparser::new(&dialect); - let expr = Expr::Cast(Cast { + let expr = Expr::cast(Cast { expr: Box::new(col("a")), data_type: DataType::Int32, }); @@ -2539,7 +2560,7 @@ mod tests { (&mysql_dialect, ×tamp_with_tz, "DATETIME"), ] { let unparser = Unparser::new(dialect); - let expr = Expr::Cast(Cast { + let expr = Expr::cast(Cast { expr: Box::new(col("a")), data_type: data_type.clone(), }); @@ -2566,7 +2587,7 @@ mod tests { ] { let unparser = Unparser::new(dialect); - let expr = Expr::Cast(Cast { + let expr = Expr::cast(Cast { expr: Box::new(col("a")), data_type, }); @@ -2583,8 +2604,8 @@ mod tests { #[test] fn test_cast_value_to_dict_expr() { let tests = [( - Expr::Cast(Cast { - expr: Box::new(Expr::Literal(ScalarValue::Utf8(Some( + Expr::cast(Cast { + expr: Box::new(Expr::literal(ScalarValue::Utf8(Some( "variation".to_string(), )))), data_type: DataType::Dictionary(Box::new(Int8), Box::new(DataType::Utf8)), @@ -2615,16 +2636,16 @@ mod tests { [(default_dialect, "DOUBLE"), (postgres_dialect, "NUMERIC")] { let unparser = Unparser::new(dialect.as_ref()); - let expr = Expr::ScalarFunction(ScalarFunction { + let expr = Expr::scalar_function(ScalarFunction { func: Arc::new(ScalarUDF::from( datafusion_functions::math::round::RoundFunc::new(), )), args: vec![ - Expr::Cast(Cast { + Expr::cast(Cast { expr: Box::new(col("a")), data_type: DataType::Float64, }), - Expr::Literal(ScalarValue::Int64(Some(2))), + Expr::literal(ScalarValue::Int64(Some(2))), ], }); let ast = unparser.expr_to_sql(&expr)?; diff --git a/datafusion/sql/src/unparser/plan.rs b/datafusion/sql/src/unparser/plan.rs index 81e47ed939f22..83916838badbf 100644 --- a/datafusion/sql/src/unparser/plan.rs +++ b/datafusion/sql/src/unparser/plan.rs @@ -91,31 +91,31 @@ impl Unparser<'_> { let plan = normalize_union_schema(plan)?; match plan { - LogicalPlan::Projection(_) - | LogicalPlan::Filter(_) - | LogicalPlan::Window(_) - | LogicalPlan::Aggregate(_) - | LogicalPlan::Sort(_) - | LogicalPlan::Join(_) - | LogicalPlan::Repartition(_) - | LogicalPlan::Union(_) - | LogicalPlan::TableScan(_) - | LogicalPlan::EmptyRelation(_) - | LogicalPlan::Subquery(_) - | LogicalPlan::SubqueryAlias(_) - | LogicalPlan::Limit(_) - | LogicalPlan::Statement(_) - | LogicalPlan::Values(_) - | LogicalPlan::Distinct(_) => self.select_to_sql_statement(&plan), - LogicalPlan::Dml(_) => self.dml_to_sql(&plan), - LogicalPlan::Explain(_) - | LogicalPlan::Analyze(_) - | LogicalPlan::Extension(_) - | LogicalPlan::Ddl(_) - | LogicalPlan::Copy(_) - | LogicalPlan::DescribeTable(_) - | LogicalPlan::RecursiveQuery(_) - | LogicalPlan::Unnest(_) => not_impl_err!("Unsupported plan: {plan:?}"), + LogicalPlan::Projection(_, _) + | LogicalPlan::Filter(_, _) + | LogicalPlan::Window(_, _) + | LogicalPlan::Aggregate(_, _) + | LogicalPlan::Sort(_, _) + | LogicalPlan::Join(_, _) + | LogicalPlan::Repartition(_, _) + | LogicalPlan::Union(_, _) + | LogicalPlan::TableScan(_, _) + | LogicalPlan::EmptyRelation(_, _) + | LogicalPlan::Subquery(_, _) + | LogicalPlan::SubqueryAlias(_, _) + | LogicalPlan::Limit(_, _) + | LogicalPlan::Statement(_, _) + | LogicalPlan::Values(_, _) + | LogicalPlan::Distinct(_, _) => self.select_to_sql_statement(&plan), + LogicalPlan::Dml(_, _) => self.dml_to_sql(&plan), + LogicalPlan::Explain(_, _) + | LogicalPlan::Analyze(_, _) + | LogicalPlan::Extension(_, _) + | LogicalPlan::Ddl(_, _) + | LogicalPlan::Copy(_, _) + | LogicalPlan::DescribeTable(_, _) + | LogicalPlan::RecursiveQuery(_, _) + | LogicalPlan::Unnest(_, _) => not_impl_err!("Unsupported plan: {plan:?}"), } } @@ -275,7 +275,7 @@ impl Unparser<'_> { relation: &mut RelationBuilder, ) -> Result<()> { match plan { - LogicalPlan::TableScan(scan) => { + LogicalPlan::TableScan(scan, _) => { if let Some(unparsed_table_scan) = Self::unparse_table_scan_pushdown(plan, None)? { @@ -304,7 +304,7 @@ impl Unparser<'_> { Ok(()) } - LogicalPlan::Projection(p) => { + LogicalPlan::Projection(p, _) => { if let Some(new_plan) = rewrite_plan_for_sort_on_non_projected_fields(p) { return self .select_to_sql_recursively(&new_plan, query, select, relation); @@ -321,7 +321,7 @@ impl Unparser<'_> { self.reconstruct_select_statement(plan, p, select)?; self.select_to_sql_recursively(p.input.as_ref(), query, select, relation) } - LogicalPlan::Filter(filter) => { + LogicalPlan::Filter(filter, _) => { if let Some(agg) = find_agg_node_within_select(plan, select.already_projected()) { @@ -341,7 +341,7 @@ impl Unparser<'_> { relation, ) } - LogicalPlan::Limit(limit) => { + LogicalPlan::Limit(limit, _) => { // Limit can be top-level plan for derived table if select.already_projected() { return self.derive_with_dialect_alias( @@ -378,7 +378,7 @@ impl Unparser<'_> { relation, ) } - LogicalPlan::Sort(sort) => { + LogicalPlan::Sort(sort, _) => { // Sort can be top-level plan for derived table if select.already_projected() { return self.derive_with_dialect_alias( @@ -419,7 +419,7 @@ impl Unparser<'_> { relation, ) } - LogicalPlan::Aggregate(agg) => { + LogicalPlan::Aggregate(agg, _) => { // Aggregation can be already handled in the projection case if !select.already_projected() { // The query returns aggregate and group expressions. If that weren't the case, @@ -448,7 +448,7 @@ impl Unparser<'_> { relation, ) } - LogicalPlan::Distinct(distinct) => { + LogicalPlan::Distinct(distinct, _) => { // Distinct can be top-level plan for derived table if select.already_projected() { return self.derive_with_dialect_alias( @@ -486,7 +486,7 @@ impl Unparser<'_> { select.distinct(Some(select_distinct)); self.select_to_sql_recursively(input, query, select, relation) } - LogicalPlan::Join(join) => { + LogicalPlan::Join(join, _) => { let mut table_scan_filters = vec![]; let left_plan = @@ -529,7 +529,7 @@ impl Unparser<'_> { // Combine `table_scan_filters` into a single filter using `AND` let Some(combined_filters) = table_scan_filters.into_iter().reduce(|acc, filter| { - Expr::BinaryExpr(BinaryExpr { + Expr::binary_expr(BinaryExpr { left: Box::new(acc), op: Operator::And, right: Box::new(filter), @@ -541,7 +541,7 @@ impl Unparser<'_> { // Combine `join.filter` with `combined_filters` using `AND` match &join.filter { - Some(filter) => Some(Expr::BinaryExpr(BinaryExpr { + Some(filter) => Some(Expr::binary_expr(BinaryExpr { left: Box::new(filter.clone()), op: Operator::And, right: Box::new(combined_filters), @@ -579,7 +579,7 @@ impl Unparser<'_> { Ok(()) } - LogicalPlan::SubqueryAlias(plan_alias) => { + LogicalPlan::SubqueryAlias(plan_alias, _) => { let (plan, mut columns) = subquery_alias_inner_query_and_columns(plan_alias); let unparsed_table_scan = Self::unparse_table_scan_pushdown( @@ -626,7 +626,7 @@ impl Unparser<'_> { Ok(()) } - LogicalPlan::Union(union) => { + LogicalPlan::Union(union, _) => { if union.inputs.len() != 2 { return not_impl_err!( "UNION ALL expected 2 inputs, but found {}", @@ -665,7 +665,7 @@ impl Unparser<'_> { Ok(()) } - LogicalPlan::Window(window) => { + LogicalPlan::Window(window, _) => { // Window nodes are handled simultaneously with Projection nodes self.select_to_sql_recursively( window.input.as_ref(), @@ -674,12 +674,14 @@ impl Unparser<'_> { relation, ) } - LogicalPlan::EmptyRelation(_) => { + LogicalPlan::EmptyRelation(_, _) => { relation.empty(); Ok(()) } - LogicalPlan::Extension(_) => not_impl_err!("Unsupported operator: {plan:?}"), - LogicalPlan::Unnest(unnest) => { + LogicalPlan::Extension(_, _) => { + not_impl_err!("Unsupported operator: {plan:?}") + } + LogicalPlan::Unnest(unnest, _) => { if !unnest.struct_type_columns.is_empty() { return internal_err!( "Struct type columns are not currently supported in UNNEST: {:?}", @@ -694,7 +696,7 @@ impl Unparser<'_> { // | Projection: table.col1, table.col2 AS UNNEST(table.col2) // | Filter: table.col3 = Int64(3) // | TableScan: table projection=None - if let LogicalPlan::Projection(p) = unnest.input.as_ref() { + if let LogicalPlan::Projection(p, _) = unnest.input.as_ref() { // continue with projection input self.select_to_sql_recursively(&p.input, query, select, relation) } else { @@ -716,7 +718,7 @@ impl Unparser<'_> { alias: Option, ) -> Result> { match plan { - LogicalPlan::TableScan(table_scan) => { + LogicalPlan::TableScan(table_scan, _) => { if !Self::is_scan_with_pushdown(table_scan) { return Ok(None); } @@ -801,7 +803,7 @@ impl Unparser<'_> { Ok(Some(builder.build()?)) } - LogicalPlan::SubqueryAlias(subquery_alias) => { + LogicalPlan::SubqueryAlias(subquery_alias, _) => { Self::unparse_table_scan_pushdown( &subquery_alias.input, Some(subquery_alias.alias.clone()), @@ -809,7 +811,7 @@ impl Unparser<'_> { } // SubqueryAlias could be rewritten to a plan with a projection as the top node by [rewrite::subquery_alias_inner_query_and_columns]. // The inner table scan could be a scan with pushdown operations. - LogicalPlan::Projection(projection) => { + LogicalPlan::Projection(projection, _) => { if let Some(plan) = Self::unparse_table_scan_pushdown(&projection.input, alias.clone())? { @@ -847,7 +849,7 @@ impl Unparser<'_> { fn select_item_to_sql(&self, expr: &Expr) -> Result { match expr { - Expr::Alias(Alias { expr, name, .. }) => { + Expr::Alias(Alias { expr, name, .. }, _) => { let inner = self.expr_to_sql(expr)?; Ok(ast::SelectItem::ExprWithAlias { @@ -908,14 +910,20 @@ impl Unparser<'_> { for (left, right) in join_conditions { match (left, right) { ( - Expr::Column(Column { - relation: _, - name: left_name, - }), - Expr::Column(Column { - relation: _, - name: right_name, - }), + Expr::Column( + Column { + relation: _, + name: left_name, + }, + _, + ), + Expr::Column( + Column { + relation: _, + name: right_name, + }, + _, + ), ) if left_name == right_name => { idents.push(self.new_ident_quoted_if_needs(left_name.to_string())); } diff --git a/datafusion/sql/src/unparser/rewrite.rs b/datafusion/sql/src/unparser/rewrite.rs index 68af121a41179..d3a8a1c88a66b 100644 --- a/datafusion/sql/src/unparser/rewrite.rs +++ b/datafusion/sql/src/unparser/rewrite.rs @@ -23,7 +23,6 @@ use datafusion_common::{ tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRewriter}, Column, HashMap, Result, TableReference, }; -use datafusion_expr::expr::Alias; use datafusion_expr::{Expr, LogicalPlan, Projection, Sort, SortExpr}; use sqlparser::ast::Ident; @@ -58,20 +57,20 @@ pub(super) fn normalize_union_schema(plan: &LogicalPlan) -> Result let plan = plan.clone(); let transformed_plan = plan.transform_up(|plan| match plan { - LogicalPlan::Union(mut union) => { + LogicalPlan::Union(mut union, _) => { let schema = Arc::unwrap_or_clone(union.schema); let schema = schema.strip_qualifiers(); union.schema = Arc::new(schema); - Ok(Transformed::yes(LogicalPlan::Union(union))) + Ok(Transformed::yes(LogicalPlan::union(union))) } - LogicalPlan::Sort(sort) => { + LogicalPlan::Sort(sort, _) => { // Only rewrite Sort expressions that have a UNION as their input - if !matches!(&*sort.input, LogicalPlan::Union(_)) { - return Ok(Transformed::no(LogicalPlan::Sort(sort))); + if !matches!(&*sort.input, LogicalPlan::Union(_, _)) { + return Ok(Transformed::no(LogicalPlan::sort(sort))); } - Ok(Transformed::yes(LogicalPlan::Sort(Sort { + Ok(Transformed::yes(LogicalPlan::sort(Sort { expr: rewrite_sort_expr_for_union(sort.expr)?, input: sort.input, fetch: sort.fetch, @@ -87,9 +86,9 @@ fn rewrite_sort_expr_for_union(exprs: Vec) -> Result> { let sort_exprs = exprs .map_elements(&mut |expr: Expr| { expr.transform_up(|expr| { - if let Expr::Column(mut col) = expr { + if let Expr::Column(mut col, _) = expr { col.relation = None; - Ok(Transformed::yes(Expr::Column(col))) + Ok(Transformed::yes(Expr::column(col))) } else { Ok(Transformed::no(expr)) } @@ -122,11 +121,11 @@ fn rewrite_sort_expr_for_union(exprs: Vec) -> Result> { pub(super) fn rewrite_plan_for_sort_on_non_projected_fields( p: &Projection, ) -> Option { - let LogicalPlan::Sort(sort) = p.input.as_ref() else { + let LogicalPlan::Sort(sort, _) = p.input.as_ref() else { return None; }; - let LogicalPlan::Projection(inner_p) = sort.input.as_ref() else { + let LogicalPlan::Projection(inner_p, _) = sort.input.as_ref() else { return None; }; @@ -136,20 +135,20 @@ pub(super) fn rewrite_plan_for_sort_on_non_projected_fields( .iter() .enumerate() .map(|(i, f)| match f { - Expr::Alias(alias) => { - let a = Expr::Column(alias.name.clone().into()); + Expr::Alias(alias, _) => { + let a = Expr::column(alias.name.clone().into()); map.insert(a.clone(), f.clone()); a } - Expr::Column(_) => { + Expr::Column(_, _) => { map.insert( - Expr::Column(inner_p.schema.field(i).name().into()), + Expr::column(inner_p.schema.field(i).name().into()), f.clone(), ); f.clone() } _ => { - let a = Expr::Column(inner_p.schema.field(i).name().into()); + let a = Expr::column(inner_p.schema.field(i).name().into()); map.insert(a.clone(), f.clone()); a } @@ -182,9 +181,9 @@ pub(super) fn rewrite_plan_for_sort_on_non_projected_fields( .collect::>(); inner_p.expr.clone_from(&new_exprs); - sort.input = Arc::new(LogicalPlan::Projection(inner_p)); + sort.input = Arc::new(LogicalPlan::projection(inner_p)); - Some(LogicalPlan::Sort(sort)) + Some(LogicalPlan::sort(sort)) } else { None } @@ -222,7 +221,7 @@ pub(super) fn subquery_alias_inner_query_and_columns( ) -> (&LogicalPlan, Vec) { let plan: &LogicalPlan = subquery_alias.input.as_ref(); - let LogicalPlan::Projection(outer_projections) = plan else { + let LogicalPlan::Projection(outer_projections, _) = plan else { return (plan, vec![]); }; @@ -236,14 +235,14 @@ pub(super) fn subquery_alias_inner_query_and_columns( // Projection: j1.j1_id AS id // Projection: j1.j1_id for (i, inner_expr) in inner_projection.expr.iter().enumerate() { - let Expr::Alias(ref outer_alias) = &outer_projections.expr[i] else { + let Expr::Alias(ref outer_alias, _) = &outer_projections.expr[i] else { return (plan, vec![]); }; // Inner projection schema fields store the projection name which is used in outer // projection expr let inner_expr_string = match inner_expr { - Expr::Column(_) => inner_expr.to_string(), + Expr::Column(_, _) => inner_expr.to_string(), _ => inner_projection.schema.field(i).name().clone(), }; @@ -270,11 +269,13 @@ pub(super) fn inject_column_aliases_into_subquery( aliases: Vec, ) -> Result { match &plan { - LogicalPlan::Projection(inner_p) => Ok(inject_column_aliases(inner_p, aliases)), + LogicalPlan::Projection(inner_p, _) => { + Ok(inject_column_aliases(inner_p, aliases)) + } _ => { // projection is wrapped by other operator (LIMIT, SORT, etc), iterate through the plan to find it plan.map_children(|child| { - if let LogicalPlan::Projection(p) = &child { + if let LogicalPlan::Projection(p, _) = &child { Ok(Transformed::yes(inject_column_aliases(p, aliases.clone()))) } else { Ok(Transformed::no(child)) @@ -303,29 +304,25 @@ pub(super) fn inject_column_aliases( .zip(aliases) .map(|(expr, col_alias)| { let relation = match &expr { - Expr::Column(col) => col.relation.clone(), + Expr::Column(col, _) => col.relation.clone(), _ => None, }; - Expr::Alias(Alias { - expr: Box::new(expr.clone()), - relation, - name: col_alias.value, - }) + expr.clone().alias_qualified(relation, col_alias.value) }) .collect::>(); updated_projection.expr = new_exprs; - LogicalPlan::Projection(updated_projection) + LogicalPlan::projection(updated_projection) } fn find_projection(logical_plan: &LogicalPlan) -> Option<&Projection> { match logical_plan { - LogicalPlan::Projection(p) => Some(p), - LogicalPlan::Limit(p) => find_projection(p.input.as_ref()), - LogicalPlan::Distinct(p) => find_projection(p.input().as_ref()), - LogicalPlan::Sort(p) => find_projection(p.input.as_ref()), + LogicalPlan::Projection(p, _) => Some(p), + LogicalPlan::Limit(p, _) => find_projection(p.input.as_ref()), + LogicalPlan::Distinct(p, _) => find_projection(p.input().as_ref()), + LogicalPlan::Sort(p, _) => find_projection(p.input.as_ref()), _ => None, } } @@ -352,13 +349,13 @@ impl TreeNodeRewriter for TableAliasRewriter<'_> { fn f_down(&mut self, expr: Expr) -> Result> { match expr { - Expr::Column(column) => { + Expr::Column(column, _) => { if let Ok(field) = self.table_schema.field_with_name(&column.name) { let new_column = Column::new(Some(self.alias_name.clone()), field.name().clone()); - Ok(Transformed::yes(Expr::Column(new_column))) + Ok(Transformed::yes(Expr::column(new_column))) } else { - Ok(Transformed::no(Expr::Column(column))) + Ok(Transformed::no(Expr::column(column))) } } _ => Ok(Transformed::no(expr)), diff --git a/datafusion/sql/src/unparser/utils.rs b/datafusion/sql/src/unparser/utils.rs index d0f80da83d63f..56dc1dc21bb43 100644 --- a/datafusion/sql/src/unparser/utils.rs +++ b/datafusion/sql/src/unparser/utils.rs @@ -50,11 +50,11 @@ pub(crate) fn find_agg_node_within_select( input.first()? }; // Agg nodes explicitly return immediately with a single node - if let LogicalPlan::Aggregate(agg) = input { + if let LogicalPlan::Aggregate(agg, _) = input { Some(agg) - } else if let LogicalPlan::TableScan(_) = input { + } else if let LogicalPlan::TableScan(_, _) = input { None - } else if let LogicalPlan::Projection(_) = input { + } else if let LogicalPlan::Projection(_, _) = input { if already_projected { None } else { @@ -76,11 +76,11 @@ pub(crate) fn find_unnest_node_within_select(plan: &LogicalPlan) -> Option<&Unne input.first()? }; - if let LogicalPlan::Unnest(unnest) = input { + if let LogicalPlan::Unnest(unnest, _) = input { Some(unnest) - } else if let LogicalPlan::TableScan(_) = input { + } else if let LogicalPlan::TableScan(_, _) = input { None - } else if let LogicalPlan::Projection(_) = input { + } else if let LogicalPlan::Projection(_, _) = input { None } else { find_unnest_node_within_select(input) @@ -107,7 +107,7 @@ pub(crate) fn find_window_nodes_within_select<'a>( // Window nodes accumulate in a vec until encountering a TableScan or 2nd projection match input { - LogicalPlan::Window(window) => { + LogicalPlan::Window(window, _) => { prev_windows = match &mut prev_windows { Some(windows) => { windows.push(window); @@ -117,14 +117,14 @@ pub(crate) fn find_window_nodes_within_select<'a>( }; find_window_nodes_within_select(input, prev_windows, already_projected) } - LogicalPlan::Projection(_) => { + LogicalPlan::Projection(_, _) => { if already_projected { prev_windows } else { find_window_nodes_within_select(input, prev_windows, true) } } - LogicalPlan::TableScan(_) => prev_windows, + LogicalPlan::TableScan(_, _) => prev_windows, _ => find_window_nodes_within_select(input, prev_windows, already_projected), } } @@ -135,14 +135,14 @@ pub(crate) fn find_window_nodes_within_select<'a>( /// it will be transformed into an actual unnest expression UNNEST([1, 2, 2, 5, NULL]) pub(crate) fn unproject_unnest_expr(expr: Expr, unnest: &Unnest) -> Result { expr.transform(|sub_expr| { - if let Expr::Column(col_ref) = &sub_expr { + if let Expr::Column(col_ref, _) = &sub_expr { // Check if the column is among the columns to run unnest on. // Currently, only List/Array columns (defined in `list_type_columns`) are supported for unnesting. if unnest.list_type_columns.iter().any(|e| e.1.output_column.name == col_ref.name) { if let Ok(idx) = unnest.schema.index_of_column(col_ref) { - if let LogicalPlan::Projection(Projection { expr, .. }) = unnest.input.as_ref() { + if let LogicalPlan::Projection(Projection { expr, .. }, _) = unnest.input.as_ref() { if let Some(unprojected_expr) = expr.get(idx) { - let unnest_expr = Expr::Unnest(expr::Unnest::new(unprojected_expr.clone())); + let unnest_expr = Expr::unnest(expr::Unnest::new(unprojected_expr.clone())); return Ok(Transformed::yes(unnest_expr)); } } @@ -169,7 +169,7 @@ pub(crate) fn unproject_agg_exprs( windows: Option<&[&Window]>, ) -> Result { expr.transform(|sub_expr| { - if let Expr::Column(c) = sub_expr { + if let Expr::Column(c, _) = sub_expr { if let Some(unprojected_expr) = find_agg_expr(agg, &c)? { Ok(Transformed::yes(unprojected_expr.clone())) } else if let Some(unprojected_expr) = @@ -196,11 +196,11 @@ pub(crate) fn unproject_agg_exprs( /// into an actual window expression as identified in the window node. pub(crate) fn unproject_window_exprs(expr: Expr, windows: &[&Window]) -> Result { expr.transform(|sub_expr| { - if let Expr::Column(c) = sub_expr { + if let Expr::Column(c, _) = sub_expr { if let Some(unproj) = find_window_expr(windows, &c.name) { Ok(Transformed::yes(unproj.clone())) } else { - Ok(Transformed::no(Expr::Column(c))) + Ok(Transformed::no(Expr::column(c))) } } else { Ok(Transformed::no(sub_expr)) @@ -211,7 +211,7 @@ pub(crate) fn unproject_window_exprs(expr: Expr, windows: &[&Window]) -> Result< fn find_agg_expr<'a>(agg: &'a Aggregate, column: &Column) -> Result> { if let Ok(index) = agg.schema.index_of_column(column) { - if matches!(agg.group_expr.as_slice(), [Expr::GroupingSet(_)]) { + if matches!(agg.group_expr.as_slice(), [Expr::GroupingSet(_, _)]) { // For grouping set expr, we must operate by expression list from the grouping set let grouping_expr = grouping_set_to_exprlist(agg.group_expr.as_slice())?; match index.cmp(&grouping_expr.len()) { @@ -255,11 +255,11 @@ pub(crate) fn unproject_sort_expr( let mut sort_expr = sort_expr.clone(); // Remove alias if present, because ORDER BY cannot use aliases - if let Expr::Alias(alias) = &sort_expr.expr { + if let Expr::Alias(alias, _) = &sort_expr.expr { sort_expr.expr = *alias.expr.clone(); } - let Expr::Column(ref col_ref) = sort_expr.expr else { + let Expr::Column(ref col_ref, _) = sort_expr.expr else { return Ok(sort_expr); }; @@ -279,10 +279,10 @@ pub(crate) fn unproject_sort_expr( // If SELECT and ORDER BY contain the same expression with a scalar function, the ORDER BY expression will // be replaced by a Column expression (e.g., "substr(customer.c_last_name, Int64(0), Int64(5))"), and we need // to transform it back to the actual expression. - if let LogicalPlan::Projection(Projection { expr, schema, .. }) = input { + if let LogicalPlan::Projection(Projection { expr, schema, .. }, _) = input { if let Ok(idx) = schema.index_of_column(col_ref) { - if let Some(Expr::ScalarFunction(scalar_fn)) = expr.get(idx) { - sort_expr.expr = Expr::ScalarFunction(scalar_fn.clone()); + if let Some(Expr::ScalarFunction(scalar_fn, _)) = expr.get(idx) { + sort_expr.expr = Expr::scalar_function(scalar_fn.clone()); } } return Ok(sort_expr); @@ -316,15 +316,15 @@ pub(crate) fn try_transform_to_simple_table_scan_with_filters( while let Some(current_plan) = plan_stack.pop() { match current_plan { - LogicalPlan::SubqueryAlias(alias) => { + LogicalPlan::SubqueryAlias(alias, _) => { table_alias = Some(alias.alias.clone()); plan_stack.push(alias.input.as_ref()); } - LogicalPlan::Filter(filter) => { + LogicalPlan::Filter(filter, _) => { filters.push(filter.predicate.clone()); plan_stack.push(filter.input.as_ref()); } - LogicalPlan::TableScan(table_scan) => { + LogicalPlan::TableScan(table_scan, _) => { let table_schema = table_scan.source.schema(); // optional rewriter if table has an alias let mut filter_alias_rewriter = @@ -381,7 +381,7 @@ pub(crate) fn date_part_to_sql( match (style, date_part_args.len()) { (DateFieldExtractStyle::Extract, 2) => { let date_expr = unparser.expr_to_sql(&date_part_args[1])?; - if let Expr::Literal(ScalarValue::Utf8(Some(field))) = &date_part_args[0] { + if let Expr::Literal(ScalarValue::Utf8(Some(field)), _) = &date_part_args[0] { let field = match field.to_lowercase().as_str() { "year" => ast::DateTimeField::Year, "month" => ast::DateTimeField::Month, @@ -402,7 +402,7 @@ pub(crate) fn date_part_to_sql( (DateFieldExtractStyle::Strftime, 2) => { let column = unparser.expr_to_sql(&date_part_args[1])?; - if let Expr::Literal(ScalarValue::Utf8(Some(field))) = &date_part_args[0] { + if let Expr::Literal(ScalarValue::Utf8(Some(field)), _) = &date_part_args[0] { let field = match field.to_lowercase().as_str() { "year" => "%Y", "month" => "%m", diff --git a/datafusion/sql/src/utils.rs b/datafusion/sql/src/utils.rs index e479bdbacd839..74bca94c84f41 100644 --- a/datafusion/sql/src/utils.rs +++ b/datafusion/sql/src/utils.rs @@ -43,10 +43,10 @@ pub(crate) fn resolve_columns(expr: &Expr, plan: &LogicalPlan) -> Result { expr.clone() .transform_up(|nested_expr| { match nested_expr { - Expr::Column(col) => { + Expr::Column(col, _) => { let (qualifier, field) = plan.schema().qualified_field_from_column(&col)?; - Ok(Transformed::yes(Expr::Column(Column::from(( + Ok(Transformed::yes(Expr::column(Column::from(( qualifier, field, ))))) } @@ -97,23 +97,23 @@ pub(crate) fn check_columns_satisfy_exprs( message_prefix: &str, ) -> Result<()> { columns.iter().try_for_each(|c| match c { - Expr::Column(_) => Ok(()), + Expr::Column(_, _) => Ok(()), _ => internal_err!("Expr::Column are required"), })?; let column_exprs = find_column_exprs(exprs); for e in &column_exprs { match e { - Expr::GroupingSet(GroupingSet::Rollup(exprs)) => { + Expr::GroupingSet(GroupingSet::Rollup(exprs), _) => { for e in exprs { check_column_satisfies_expr(columns, e, message_prefix)?; } } - Expr::GroupingSet(GroupingSet::Cube(exprs)) => { + Expr::GroupingSet(GroupingSet::Cube(exprs), _) => { for e in exprs { check_column_satisfies_expr(columns, e, message_prefix)?; } } - Expr::GroupingSet(GroupingSet::GroupingSets(lists_of_exprs)) => { + Expr::GroupingSet(GroupingSet::GroupingSets(lists_of_exprs), _) => { for exprs in lists_of_exprs { for e in exprs { check_column_satisfies_expr(columns, e, message_prefix)?; @@ -148,7 +148,9 @@ pub(crate) fn extract_aliases(exprs: &[Expr]) -> HashMap { exprs .iter() .filter_map(|expr| match expr { - Expr::Alias(Alias { expr, name, .. }) => Some((name.clone(), *expr.clone())), + Expr::Alias(Alias { expr, name, .. }, _) => { + Some((name.clone(), *expr.clone())) + } _ => None, }) .collect::>() @@ -165,17 +167,17 @@ pub(crate) fn resolve_positions_to_exprs( match expr { // sql_expr_to_logical_expr maps number to i64 // https://github.com/apache/datafusion/blob/8d175c759e17190980f270b5894348dc4cff9bbf/datafusion/src/sql/planner.rs#L882-L887 - Expr::Literal(ScalarValue::Int64(Some(position))) + Expr::Literal(ScalarValue::Int64(Some(position)), _) if position > 0_i64 && position <= select_exprs.len() as i64 => { let index = (position - 1) as usize; let select_expr = &select_exprs[index]; Ok(match select_expr { - Expr::Alias(Alias { expr, .. }) => *expr.clone(), + Expr::Alias(Alias { expr, .. }, _) => *expr.clone(), _ => select_expr.clone(), }) } - Expr::Literal(ScalarValue::Int64(Some(position))) => plan_err!( + Expr::Literal(ScalarValue::Int64(Some(position)), _) => plan_err!( "Cannot find column with position {} in SELECT clause. Valid columns: 1 to {}", position, select_exprs.len() ), @@ -190,11 +192,11 @@ pub(crate) fn resolve_aliases_to_exprs( aliases: &HashMap, ) -> Result { expr.transform_up(|nested_expr| match nested_expr { - Expr::Column(c) if c.relation.is_none() => { + Expr::Column(c, _) if c.relation.is_none() => { if let Some(aliased_expr) = aliases.get(&c.name) { Ok(Transformed::yes(aliased_expr.clone())) } else { - Ok(Transformed::no(Expr::Column(c))) + Ok(Transformed::no(Expr::column(c))) } } _ => Ok(Transformed::no(nested_expr)), @@ -208,9 +210,11 @@ pub fn window_expr_common_partition_keys(window_exprs: &[Expr]) -> Result<&[Expr let all_partition_keys = window_exprs .iter() .map(|expr| match expr { - Expr::WindowFunction(WindowFunction { partition_by, .. }) => Ok(partition_by), - Expr::Alias(Alias { expr, .. }) => match expr.as_ref() { - Expr::WindowFunction(WindowFunction { partition_by, .. }) => { + Expr::WindowFunction(WindowFunction { partition_by, .. }, _) => { + Ok(partition_by) + } + Expr::Alias(Alias { expr, .. }, _) => match expr.as_ref() { + Expr::WindowFunction(WindowFunction { partition_by, .. }, _) => { Ok(partition_by) } expr => exec_err!("Impossibly got non-window expr {expr:?}"), @@ -382,7 +386,7 @@ impl<'a> RecursiveUnnestRewriter<'a> { Ok( get_struct_unnested_columns(&placeholder_name, &inner_fields) .into_iter() - .map(Expr::Column) + .map(Expr::column) .collect(), ) } @@ -424,7 +428,7 @@ impl<'a> TreeNodeRewriter for RecursiveUnnestRewriter<'a> { /// - If some unnest expr has been visited, maintain a stack of such information, this /// is used to detect if some recursive unnest expr exists (e.g **unnest(unnest(unnest(3d column))))** fn f_down(&mut self, expr: Expr) -> Result> { - if let Expr::Unnest(ref unnest_expr) = expr { + if let Expr::Unnest(ref unnest_expr, _) = expr { let (data_type, _) = unnest_expr.expr.data_type_and_nullable(self.input_schema)?; self.consecutive_unnest.push(Some(unnest_expr.clone())); @@ -481,7 +485,7 @@ impl<'a> TreeNodeRewriter for RecursiveUnnestRewriter<'a> { /// ``` /// fn f_up(&mut self, expr: Expr) -> Result> { - if let Expr::Unnest(ref traversing_unnest) = expr { + if let Expr::Unnest(ref traversing_unnest, _) = expr { if traversing_unnest == self.top_most_unnest.as_ref().unwrap() { self.top_most_unnest = None; } @@ -534,7 +538,7 @@ impl<'a> TreeNodeRewriter for RecursiveUnnestRewriter<'a> { // retain their projection // e.g given expr tree unnest(col_a) + col_b, we have to retain projection of col_b // this condition can be checked by maintaining an Option - if matches!(&expr, Expr::Column(_)) && self.top_most_unnest.is_none() { + if matches!(&expr, Expr::Column(_, _)) && self.top_most_unnest.is_none() { push_projection_dedupl(self.inner_projection_exprs, expr.clone()); } @@ -592,7 +596,7 @@ pub(crate) fn rewrite_recursive_unnest_bottom_up( } = original_expr.clone().rewrite(&mut rewriter)?; if !transformed { - if matches!(&transformed_expr, Expr::Column(_)) + if matches!(&transformed_expr, Expr::Column(_, _)) || matches!(&transformed_expr, Expr::Wildcard { .. }) { push_projection_dedupl(inner_projection_exprs, transformed_expr.clone()); @@ -602,7 +606,7 @@ pub(crate) fn rewrite_recursive_unnest_bottom_up( // outer projection just select its name let column_name = transformed_expr.schema_name().to_string(); push_projection_dedupl(inner_projection_exprs, transformed_expr); - Ok(vec![Expr::Column(Column::from_name(column_name))]) + Ok(vec![Expr::column(Column::from_name(column_name))]) } } else { if let Some(transformed_root_exprs) = rewriter.transformed_root_exprs { @@ -671,7 +675,7 @@ mod tests { let dfschema = DFSchema::try_from(schema)?; - let input = LogicalPlan::EmptyRelation(EmptyRelation { + let input = LogicalPlan::empty_relation(EmptyRelation { produce_one_row: false, schema: Arc::new(dfschema), }); @@ -775,7 +779,7 @@ mod tests { let dfschema = DFSchema::try_from(schema)?; - let input = LogicalPlan::EmptyRelation(EmptyRelation { + let input = LogicalPlan::empty_relation(EmptyRelation { produce_one_row: false, schema: Arc::new(dfschema), }); @@ -887,7 +891,7 @@ mod tests { let dfschema = DFSchema::try_from(schema)?; - let input = LogicalPlan::EmptyRelation(EmptyRelation { + let input = LogicalPlan::empty_relation(EmptyRelation { produce_one_row: false, schema: Arc::new(dfschema), }); diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index ab7e6c8d0bb73..914526a9f996b 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -206,10 +206,14 @@ fn test_parse_options_value_normalization() { assert_eq!(expected_plan, format!("{plan}")); match plan { - LogicalPlan::Ddl(DdlStatement::CreateExternalTable( - CreateExternalTable { options, .. }, - )) - | LogicalPlan::Copy(CopyTo { options, .. }) => { + LogicalPlan::Ddl( + DdlStatement::CreateExternalTable(CreateExternalTable { + options, + .. + }), + _, + ) + | LogicalPlan::Copy(CopyTo { options, .. }, _) => { expected_options.iter().for_each(|(k, v)| { assert_eq!(Some(&v.to_string()), options.get(*k)); }); @@ -2711,7 +2715,7 @@ fn prepare_stmt_quick_test( assert_eq!(format!("{assert_plan}"), expected_plan); // verify data types - if let LogicalPlan::Statement(Statement::Prepare(Prepare { data_types, .. })) = + if let LogicalPlan::Statement(Statement::Prepare(Prepare { data_types, .. }), _) = assert_plan { let dt = format!("{data_types:?}"); @@ -4436,15 +4440,18 @@ fn plan_create_index() { "CREATE UNIQUE INDEX IF NOT EXISTS idx_name ON test USING btree (name, age DESC)"; let plan = logical_plan_with_options(sql, ParserOptions::default()).unwrap(); match plan { - LogicalPlan::Ddl(DdlStatement::CreateIndex(CreateIndex { - name, - table, - using, - columns, - unique, - if_not_exists, - .. - })) => { + LogicalPlan::Ddl( + DdlStatement::CreateIndex(CreateIndex { + name, + table, + using, + columns, + unique, + if_not_exists, + .. + }), + _, + ) => { assert_eq!(name, Some("idx_name".to_string())); assert_eq!(format!("{table}"), "test"); assert_eq!(using, Some("btree".to_string())); diff --git a/datafusion/sqllogictest/test_files/arrow_typeof.slt b/datafusion/sqllogictest/test_files/arrow_typeof.slt index 77b10b41ccb3d..da7da7608d8c5 100644 --- a/datafusion/sqllogictest/test_files/arrow_typeof.slt +++ b/datafusion/sqllogictest/test_files/arrow_typeof.slt @@ -95,7 +95,7 @@ SELECT arrow_cast('1', 'Int16') query error SELECT arrow_cast('1') -query error DataFusion error: Error during planning: arrow_cast requires its second argument to be a constant string, got Literal\(Int64\(43\)\) +query error DataFusion error: Error during planning: arrow_cast requires its second argument to be a constant string, got Literal\(Int64\(43\), LogicalPlanStats \{ patterns: EnumSet\(ExprLiteral\) \}\) SELECT arrow_cast('1', 43) query error Error unrecognized word: unknown diff --git a/datafusion/sqllogictest/test_files/explain.slt b/datafusion/sqllogictest/test_files/explain.slt index f3fee4f1fca63..9322ea7926e47 100644 --- a/datafusion/sqllogictest/test_files/explain.slt +++ b/datafusion/sqllogictest/test_files/explain.slt @@ -429,7 +429,7 @@ logical_plan 04)------Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count(*)]] 05)--------TableScan: t2 06)--TableScan: t1 projection=[a] -physical_plan_error This feature is not implemented: Physical plan does not support logical expression Exists(Exists { subquery: , negated: false }) +physical_plan_error This feature is not implemented: Physical plan does not support logical expression Exists(Exists { subquery: , negated: false }, LogicalPlanStats { patterns: EnumSet(ExprColumn | ExprLiteral | ExprAggregateFunction) }) statement ok drop table t1; diff --git a/datafusion/sqllogictest/test_files/joins.slt b/datafusion/sqllogictest/test_files/joins.slt index e636e93007a4a..37a9bc46d8c7e 100644 --- a/datafusion/sqllogictest/test_files/joins.slt +++ b/datafusion/sqllogictest/test_files/joins.slt @@ -4062,12 +4062,12 @@ logical_plan 07)--------Unnest: lists[unnest_placeholder(generate_series(Int64(1),outer_ref(t1.t1_int)))|depth=1] structs[] 08)----------Projection: generate_series(Int64(1), CAST(outer_ref(t1.t1_int) AS Int64)) AS unnest_placeholder(generate_series(Int64(1),outer_ref(t1.t1_int))) 09)------------EmptyRelation -physical_plan_error This feature is not implemented: Physical plan does not support logical expression OuterReferenceColumn(UInt32, Column { relation: Some(Bare { table: "t1" }), name: "t1_int" }) +physical_plan_error This feature is not implemented: Physical plan does not support logical expression OuterReferenceColumn(UInt32, Column { relation: Some(Bare { table: "t1" }), name: "t1_int" }, LogicalPlanStats { patterns: EnumSet() }) # Test CROSS JOIN LATERAL syntax (execution) # TODO: https://github.com/apache/datafusion/issues/10048 -query error DataFusion error: This feature is not implemented: Physical plan does not support logical expression OuterReferenceColumn\(UInt32, Column \{ relation: Some\(Bare \{ table: "t1" \}\), name: "t1_int" \}\) +query error DataFusion error: This feature is not implemented: Physical plan does not support logical expression OuterReferenceColumn\(UInt32, Column \{ relation: Some\(Bare \{ table: "t1" \}\), name: "t1_int" \}, LogicalPlanStats \{ patterns: EnumSet\(\) \}\) select t1_id, t1_name, i from join_t1 t1 cross join lateral (select * from unnest(generate_series(1, t1_int))) as series(i); @@ -4085,12 +4085,12 @@ logical_plan 07)--------Unnest: lists[unnest_placeholder(generate_series(Int64(1),outer_ref(t2.t1_int)))|depth=1] structs[] 08)----------Projection: generate_series(Int64(1), CAST(outer_ref(t2.t1_int) AS Int64)) AS unnest_placeholder(generate_series(Int64(1),outer_ref(t2.t1_int))) 09)------------EmptyRelation -physical_plan_error This feature is not implemented: Physical plan does not support logical expression OuterReferenceColumn(UInt32, Column { relation: Some(Bare { table: "t2" }), name: "t1_int" }) +physical_plan_error This feature is not implemented: Physical plan does not support logical expression OuterReferenceColumn(UInt32, Column { relation: Some(Bare { table: "t2" }), name: "t1_int" }, LogicalPlanStats { patterns: EnumSet() }) # Test INNER JOIN LATERAL syntax (execution) # TODO: https://github.com/apache/datafusion/issues/10048 -query error DataFusion error: This feature is not implemented: Physical plan does not support logical expression OuterReferenceColumn\(UInt32, Column \{ relation: Some\(Bare \{ table: "t2" \}\), name: "t1_int" \}\) +query error DataFusion error: This feature is not implemented: Physical plan does not support logical expression OuterReferenceColumn\(UInt32, Column \{ relation: Some\(Bare \{ table: "t2" \}\), name: "t1_int" \}, LogicalPlanStats \{ patterns: EnumSet\(\) \}\) select t1_id, t1_name, i from join_t1 t2 inner join lateral (select * from unnest(generate_series(1, t1_int))) as series(i) on(t1_id > i); # Test RIGHT JOIN LATERAL syntax (unsupported) diff --git a/datafusion/sqllogictest/test_files/update.slt b/datafusion/sqllogictest/test_files/update.slt index 0f9582b04c589..c72632e502d81 100644 --- a/datafusion/sqllogictest/test_files/update.slt +++ b/datafusion/sqllogictest/test_files/update.slt @@ -60,7 +60,7 @@ logical_plan 06)----------Filter: outer_ref(t1.a) = t2.a 07)------------TableScan: t2 08)----TableScan: t1 -physical_plan_error This feature is not implemented: Physical plan does not support logical expression ScalarSubquery() +physical_plan_error This feature is not implemented: Physical plan does not support logical expression ScalarSubquery(, LogicalPlanStats { patterns: EnumSet(ExprColumn | ExprBinaryExpr | ExprAggregateFunction) }) # set from other table query TT diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index c81fa3e29e1cf..3fc9f8600cb57 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -166,7 +166,7 @@ fn split_eq_and_noneq_join_predicate_with_nulls_equality( for expr in exprs { #[allow(clippy::collapsible_match)] match expr { - Expr::BinaryExpr(binary_expr) => match binary_expr { + Expr::BinaryExpr(binary_expr, _) => match binary_expr { x @ (BinaryExpr { left, op: Operator::Eq, @@ -184,7 +184,7 @@ fn split_eq_and_noneq_join_predicate_with_nulls_equality( }; match (left.as_ref(), right.as_ref()) { - (Expr::Column(l), Expr::Column(r)) => { + (Expr::Column(l, _), Expr::Column(r, _)) => { accum_join_keys.push((l.clone(), r.clone())); } _ => accum_filters.push(expr.clone()), @@ -292,16 +292,16 @@ pub async fn from_substrait_plan( match plan { // If the last node of the plan produces expressions, bake the renames into those expressions. // This isn't necessary for correctness, but helps with roundtrip tests. - LogicalPlan::Projection(p) => Ok(LogicalPlan::Projection(Projection::try_new(rename_expressions(p.expr, p.input.schema(), renamed_schema.fields())?, p.input)?)), - LogicalPlan::Aggregate(a) => { + LogicalPlan::Projection(p, _) => Ok(LogicalPlan::projection(Projection::try_new(rename_expressions(p.expr, p.input.schema(), renamed_schema.fields())?, p.input)?)), + LogicalPlan::Aggregate(a, _) => { let (group_fields, expr_fields) = renamed_schema.fields().split_at(a.group_expr.len()); let new_group_exprs = rename_expressions(a.group_expr, a.input.schema(), group_fields)?; let new_aggr_exprs = rename_expressions(a.aggr_expr, a.input.schema(), expr_fields)?; - Ok(LogicalPlan::Aggregate(Aggregate::try_new(a.input, new_group_exprs, new_aggr_exprs)?)) + Ok(LogicalPlan::aggregate(Aggregate::try_new(a.input, new_group_exprs, new_aggr_exprs)?)) }, // There are probably more plans where we could bake things in, can add them later as needed. // Otherwise, add a new Project to handle the renaming. - _ => Ok(LogicalPlan::Projection(Projection::try_new(rename_expressions(plan.schema().columns().iter().map(|c| col(c.to_owned())), plan.schema(), renamed_schema.fields())?, Arc::new(plan))?)) + _ => Ok(LogicalPlan::projection(Projection::try_new(rename_expressions(plan.schema().columns().iter().map(|c| col(c.to_owned())), plan.schema(), renamed_schema.fields())?, Arc::new(plan))?)) } } }, @@ -437,7 +437,7 @@ fn rename_expressions( .map(|(old_expr, new_field)| { // Check if type (i.e. nested struct field names) match, use Cast to rename if needed let new_expr = if &old_expr.get_type(input_schema)? != new_field.data_type() { - Expr::Cast(Cast::new( + Expr::cast(Cast::new( Box::new(old_expr), new_field.data_type().to_owned(), )) @@ -447,7 +447,7 @@ fn rename_expressions( // Alias column if needed to fix the top-level name match &new_expr { // If expr is a column reference, alias_if_changed would cause an aliasing if the old expr has a qualifier - Expr::Column(c) if &c.name == new_field.name() => Ok(new_expr), + Expr::Column(c, _) if &c.name == new_field.name() => Ok(new_expr), _ => new_expr.alias_if_changed(new_field.name().to_owned()), } }) @@ -594,7 +594,7 @@ pub async fn from_substrait_rel( ) .await?; // if the expression is WindowFunction, wrap in a Window relation - if let Expr::WindowFunction(_) = &e { + if let Expr::WindowFunction(_, _) = &e { // Adding the same expression here and in the project below // works because the project's builder uses columnize_expr(..) // to transform it into a column reference @@ -605,7 +605,7 @@ pub async fn from_substrait_rel( let mut final_exprs: Vec = vec![]; for index in 0..original_schema.fields().len() { - let e = Expr::Column(Column::from( + let e = Expr::column(Column::from( original_schema.qualified_field(index), )); final_exprs.push(name_tracker.get_uniquely_named_expr(e)?); @@ -710,7 +710,7 @@ pub async fn from_substrait_rel( // Note that GroupingSet::Rollup would become GroupingSet::GroupingSets, when // parsed by the producer and consumer, since Substrait does not have a type dedicated // to ROLLUP. Only vector of Groupings (grouping sets) is available. - group_exprs.push(Expr::GroupingSet(GroupingSet::GroupingSets( + group_exprs.push(Expr::grouping_set(GroupingSet::GroupingSets( grouping_sets, ))); } @@ -912,7 +912,7 @@ pub async fn from_substrait_rel( } Some(ReadType::VirtualTable(vt)) => { if vt.values.is_empty() { - return Ok(LogicalPlan::EmptyRelation(EmptyRelation { + return Ok(LogicalPlan::empty_relation(EmptyRelation { produce_one_row: false, schema: DFSchemaRef::new(substrait_schema), })); @@ -928,7 +928,7 @@ pub async fn from_substrait_rel( .iter() .map(|lit| { name_idx += 1; // top-level names are provided through schema - Ok(Expr::Literal(from_substrait_literal( + Ok(Expr::literal(from_substrait_literal( lit, extensions, &named_struct.names, @@ -947,7 +947,7 @@ pub async fn from_substrait_rel( }) .collect::>()?; - Ok(LogicalPlan::Values(Values { + Ok(LogicalPlan::values(Values { schema: DFSchemaRef::new(substrait_schema), values, })) @@ -1044,7 +1044,7 @@ pub async fn from_substrait_rel( let plan = state .serializer_registry() .deserialize_logical_plan(&ext_detail.type_url, &ext_detail.value)?; - Ok(LogicalPlan::Extension(Extension { node: plan })) + Ok(LogicalPlan::extension(Extension { node: plan })) } Some(RelType::ExtensionSingle(extension)) => { let Some(ext_detail) = &extension.detail else { @@ -1061,7 +1061,7 @@ pub async fn from_substrait_rel( let input_plan = from_substrait_rel(state, input_rel, extensions).await?; let plan = plan.with_exprs_and_inputs(plan.expressions(), vec![input_plan])?; - Ok(LogicalPlan::Extension(Extension { node: plan })) + Ok(LogicalPlan::extension(Extension { node: plan })) } Some(RelType::ExtensionMulti(extension)) => { let Some(ext_detail) = &extension.detail else { @@ -1076,7 +1076,7 @@ pub async fn from_substrait_rel( inputs.push(input_plan); } let plan = plan.with_exprs_and_inputs(plan.expressions(), inputs)?; - Ok(LogicalPlan::Extension(Extension { node: plan })) + Ok(LogicalPlan::extension(Extension { node: plan })) } Some(RelType::Exchange(exchange)) => { let Some(input) = exchange.input.as_ref() else { @@ -1112,7 +1112,7 @@ pub async fn from_substrait_rel( return not_impl_err!("Unsupported exchange kind: {exchange_kind:?}"); } }; - Ok(LogicalPlan::Repartition(Repartition { + Ok(LogicalPlan::repartition(Repartition { input, partitioning_scheme, })) @@ -1180,7 +1180,7 @@ fn apply_emit_kind( // expressions in the projection are volatile. This is to avoid issues like // converting a single call of the random() function into multiple calls due to // duplicate fields in the output_mapping. - LogicalPlan::Projection(proj) if !contains_volatile_expr(&proj) => { + LogicalPlan::Projection(proj, _) if !contains_volatile_expr(&proj) => { let mut exprs: Vec = vec![]; for field in output_mapping { let expr = proj.expr @@ -1201,7 +1201,7 @@ fn apply_emit_kind( let mut exprs: Vec = vec![]; for index in output_mapping.into_iter() { - let column = Expr::Column(Column::from( + let column = Expr::column(Column::from( input_schema.qualified_field(index as usize), )); let expr = name_tracker.get_uniquely_named_expr(column)?; @@ -1290,7 +1290,7 @@ fn apply_projection( let df_schema = df_schema.to_owned(); match plan { - LogicalPlan::TableScan(mut scan) => { + LogicalPlan::TableScan(mut scan, _) => { let column_indices: Vec = substrait_schema .strip_qualifiers() .fields() @@ -1314,7 +1314,7 @@ fn apply_projection( )?); scan.projection = Some(column_indices); - Ok(LogicalPlan::TableScan(scan)) + Ok(LogicalPlan::table_scan(scan)) } _ => plan_err!("DataFrame passed to apply_projection must be a TableScan"), } @@ -1525,12 +1525,12 @@ pub async fn from_substrait_agg_func( if let Ok(fun) = state.udaf(function_name) { // deal with situation that count(*) got no arguments let args = if fun.name() == "count" && args.is_empty() { - vec![Expr::Literal(ScalarValue::Int64(Some(1)))] + vec![Expr::literal(ScalarValue::Int64(Some(1)))] } else { args }; - Ok(Arc::new(Expr::AggregateFunction( + Ok(Arc::new(Expr::aggregate_function( expr::AggregateFunction::new_udf(fun, args, distinct, filter, order_by, None), ))) } else { @@ -1554,7 +1554,7 @@ pub async fn from_substrait_rex( Some(RexType::SingularOrList(s)) => { let substrait_expr = s.value.as_ref().unwrap(); let substrait_list = s.options.as_ref(); - Ok(Expr::InList(InList { + Ok(Expr::_in_list(InList { expr: Box::new( from_substrait_rex(state, substrait_expr, input_schema, extensions) .await?, @@ -1621,7 +1621,7 @@ pub async fn from_substrait_rex( )), None => None, }; - Ok(Expr::Case(Case { + Ok(Expr::case(Case { expr, when_then_expr, else_expr, @@ -1643,7 +1643,7 @@ pub async fn from_substrait_rex( // try to first match the requested function into registered udfs, then built-in ops // and finally built-in expressions if let Ok(func) = state.udf(fn_name) { - Ok(Expr::ScalarFunction(expr::ScalarFunction::new_udf( + Ok(Expr::scalar_function(expr::ScalarFunction::new_udf( func.to_owned(), args, ))) @@ -1660,7 +1660,7 @@ pub async fn from_substrait_rex( .into_iter() .fold(None, |combined_expr: Option, arg: Expr| { Some(match combined_expr { - Some(expr) => Expr::BinaryExpr(BinaryExpr { + Some(expr) => Expr::binary_expr(BinaryExpr { left: Box::new(expr), op, right: Box::new(arg), @@ -1679,10 +1679,10 @@ pub async fn from_substrait_rex( } Some(RexType::Literal(lit)) => { let scalar_value = from_substrait_literal_without_names(lit, extensions)?; - Ok(Expr::Literal(scalar_value)) + Ok(Expr::literal(scalar_value)) } Some(RexType::Cast(cast)) => match cast.as_ref().r#type.as_ref() { - Some(output_type) => Ok(Expr::Cast(Cast::new( + Some(output_type) => Ok(Expr::cast(Cast::new( Box::new( from_substrait_rex( state, @@ -1740,7 +1740,7 @@ pub async fn from_substrait_rex( } } }; - Ok(Expr::WindowFunction(expr::WindowFunction { + Ok(Expr::window_function(expr::WindowFunction { fun, args: from_substrait_func_args( state, @@ -1778,7 +1778,7 @@ pub async fn from_substrait_rex( from_substrait_rel(state, haystack_expr, extensions) .await?; let outer_refs = haystack_expr.all_out_ref_exprs(); - Ok(Expr::InSubquery(InSubquery { + Ok(Expr::in_subquery(InSubquery { expr: Box::new( from_substrait_rex( state, @@ -1807,7 +1807,7 @@ pub async fn from_substrait_rex( ) .await?; let outer_ref_columns = plan.all_out_ref_exprs(); - Ok(Expr::ScalarSubquery(Subquery { + Ok(Expr::scalar_subquery(Subquery { subquery: Arc::new(plan), outer_ref_columns, })) @@ -1824,7 +1824,7 @@ pub async fn from_substrait_rex( ) .await?; let outer_ref_columns = plan.all_out_ref_exprs(); - Ok(Expr::Exists(Exists::new( + Ok(Expr::exists(Exists::new( Subquery { subquery: Arc::new(plan), outer_ref_columns, @@ -2831,7 +2831,7 @@ fn from_substrait_field_reference( Some(_) => not_impl_err!( "Direct reference StructField with child is not supported" ), - None => Ok(Expr::Column(Column::from( + None => Ok(Expr::column(Column::from( input_schema.qualified_field(x.field as usize), ))), }, @@ -2910,16 +2910,16 @@ impl BuiltinExprBuilder { let arg = Box::new(arg); let expr = match fn_name { - "not" => Expr::Not(arg), - "negative" | "negate" => Expr::Negative(arg), - "is_null" => Expr::IsNull(arg), - "is_not_null" => Expr::IsNotNull(arg), - "is_true" => Expr::IsTrue(arg), - "is_false" => Expr::IsFalse(arg), - "is_not_true" => Expr::IsNotTrue(arg), - "is_not_false" => Expr::IsNotFalse(arg), - "is_unknown" => Expr::IsUnknown(arg), - "is_not_unknown" => Expr::IsNotUnknown(arg), + "not" => Expr::_not(arg), + "negative" | "negate" => Expr::negative(arg), + "is_null" => Expr::_is_null(arg), + "is_not_null" => Expr::_is_not_null(arg), + "is_true" => Expr::_is_true(arg), + "is_false" => Expr::_is_false(arg), + "is_not_true" => Expr::_is_not_true(arg), + "is_not_false" => Expr::_is_not_false(arg), + "is_unknown" => Expr::_is_unknown(arg), + "is_not_unknown" => Expr::_is_not_unknown(arg), _ => return not_impl_err!("Unsupported builtin expression: {}", fn_name), }; @@ -2966,7 +2966,7 @@ impl BuiltinExprBuilder { .await?; match escape_char_expr { - Expr::Literal(ScalarValue::Utf8(escape_char_string)) => { + Expr::Literal(ScalarValue::Utf8(escape_char_string), _) => { // Convert Option to Option escape_char_string.and_then(|s| s.chars().next()) } @@ -2980,7 +2980,7 @@ impl BuiltinExprBuilder { None }; - Ok(Expr::Like(Like { + Ok(Expr::_like(Like { negated: false, expr: Box::new(expr), pattern: Box::new(pattern), diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index 29019dfd74f32..a877b42b34c57 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -192,7 +192,7 @@ pub fn to_substrait_rel( extensions: &mut Extensions, ) -> Result> { match plan { - LogicalPlan::TableScan(scan) => { + LogicalPlan::TableScan(scan, _) => { let projection = scan.projection.as_ref().map(|p| { p.iter() .map(|i| StructItem { @@ -225,7 +225,7 @@ pub fn to_substrait_rel( }))), })) } - LogicalPlan::EmptyRelation(e) => { + LogicalPlan::EmptyRelation(e, _) => { if e.produce_one_row { return not_impl_err!( "Producing a row from empty relation is unsupported" @@ -246,7 +246,7 @@ pub fn to_substrait_rel( }))), })) } - LogicalPlan::Values(v) => { + LogicalPlan::Values(v, _) => { let values = v .values .iter() @@ -254,10 +254,10 @@ pub fn to_substrait_rel( let fields = row .iter() .map(|v| match v { - Expr::Literal(sv) => to_substrait_literal(sv, extensions), - Expr::Alias(alias) => match alias.expr.as_ref() { + Expr::Literal(sv, _) => to_substrait_literal(sv, extensions), + Expr::Alias(alias, _) => match alias.expr.as_ref() { // The schema gives us the names, so we can skip aliases - Expr::Literal(sv) => to_substrait_literal(sv, extensions), + Expr::Literal(sv, _) => to_substrait_literal(sv, extensions), _ => Err(substrait_datafusion_err!( "Only literal types can be aliased in Virtual Tables, got: {}", alias.expr.variant_name() )), @@ -285,7 +285,7 @@ pub fn to_substrait_rel( }))), })) } - LogicalPlan::Projection(p) => { + LogicalPlan::Projection(p, _) => { let expressions = p .expr .iter() @@ -311,7 +311,7 @@ pub fn to_substrait_rel( }))), })) } - LogicalPlan::Filter(filter) => { + LogicalPlan::Filter(filter, _) => { let input = to_substrait_rel(filter.input.as_ref(), state, extensions)?; let filter_expr = to_substrait_rex( state, @@ -329,7 +329,7 @@ pub fn to_substrait_rel( }))), })) } - LogicalPlan::Limit(limit) => { + LogicalPlan::Limit(limit, _) => { let input = to_substrait_rel(limit.input.as_ref(), state, extensions)?; let FetchType::Literal(fetch) = limit.get_fetch_type()? else { return not_impl_err!("Non-literal limit fetch"); @@ -348,7 +348,7 @@ pub fn to_substrait_rel( }))), })) } - LogicalPlan::Sort(sort) => { + LogicalPlan::Sort(sort, _) => { let input = to_substrait_rel(sort.input.as_ref(), state, extensions)?; let sort_fields = sort .expr @@ -364,7 +364,7 @@ pub fn to_substrait_rel( }))), })) } - LogicalPlan::Aggregate(agg) => { + LogicalPlan::Aggregate(agg, _) => { let input = to_substrait_rel(agg.input.as_ref(), state, extensions)?; let (grouping_expressions, groupings) = to_substrait_groupings( state, @@ -391,7 +391,7 @@ pub fn to_substrait_rel( }))), })) } - LogicalPlan::Distinct(Distinct::All(plan)) => { + LogicalPlan::Distinct(Distinct::All(plan), _) => { // Use Substrait's AggregateRel with empty measures to represent `select distinct` let input = to_substrait_rel(plan.as_ref(), state, extensions)?; // Get grouping keys from the input relation's number of output fields @@ -413,7 +413,7 @@ pub fn to_substrait_rel( }))), })) } - LogicalPlan::Join(join) => { + LogicalPlan::Join(join, _) => { let left = to_substrait_rel(join.left.as_ref(), state, extensions)?; let right = to_substrait_rel(join.right.as_ref(), state, extensions)?; let join_type = to_substrait_jointype(join.join_type); @@ -483,12 +483,12 @@ pub fn to_substrait_rel( }))), })) } - LogicalPlan::SubqueryAlias(alias) => { + LogicalPlan::SubqueryAlias(alias, _) => { // Do nothing if encounters SubqueryAlias // since there is no corresponding relation type in Substrait to_substrait_rel(alias.input.as_ref(), state, extensions) } - LogicalPlan::Union(union) => { + LogicalPlan::Union(union, _) => { let input_rels = union .inputs .iter() @@ -506,7 +506,7 @@ pub fn to_substrait_rel( })), })) } - LogicalPlan::Window(window) => { + LogicalPlan::Window(window, _) => { let input = to_substrait_rel(window.input.as_ref(), state, extensions)?; // create a field reference for each input field @@ -545,7 +545,7 @@ pub fn to_substrait_rel( rel_type: Some(RelType::Project(project_rel)), })) } - LogicalPlan::Repartition(repartition) => { + LogicalPlan::Repartition(repartition, _) => { let input = to_substrait_rel(repartition.input.as_ref(), state, extensions)?; let partition_count = match repartition.partitioning_scheme { Partitioning::RoundRobinBatch(num) => num, @@ -591,7 +591,7 @@ pub fn to_substrait_rel( rel_type: Some(RelType::Exchange(Box::new(exchange_rel))), })) } - LogicalPlan::Extension(extension_plan) => { + LogicalPlan::Extension(extension_plan, _) => { let extension_bytes = state .serializer_registry() .serialize_logical_plan(extension_plan.node.as_ref())?; @@ -806,7 +806,7 @@ pub fn to_substrait_groupings( let mut ref_group_exprs = vec![]; let groupings = match exprs.len() { 1 => match &exprs[0] { - Expr::GroupingSet(gs) => match gs { + Expr::GroupingSet(gs, _) => match gs { GroupingSet::Cube(_) => Err(DataFusionError::NotImplemented( "GroupingSet CUBE is not yet supported".to_string(), )), @@ -869,7 +869,7 @@ pub fn to_substrait_agg_measure( extensions: &mut Extensions, ) -> Result { match expr { - Expr::AggregateFunction(expr::AggregateFunction { func, args, distinct, filter, order_by, null_treatment: _, }) => { + Expr::AggregateFunction(expr::AggregateFunction { func, args, distinct, filter, order_by, null_treatment: _, }, _) => { let sorts = if let Some(order_by) = order_by { order_by.iter().map(|expr| to_substrait_sort_field(state, expr, schema, extensions)).collect::>>()? } else { @@ -901,7 +901,7 @@ pub fn to_substrait_agg_measure( }) } - Expr::Alias(Alias{expr,..})=> { + Expr::Alias(Alias{expr,..}, _)=> { to_substrait_agg_measure(state, expr, schema, extensions) } _ => internal_err!( @@ -990,11 +990,14 @@ pub fn to_substrait_rex( extensions: &mut Extensions, ) -> Result { match expr { - Expr::InList(InList { - expr, - list, - negated, - }) => { + Expr::InList( + InList { + expr, + list, + negated, + }, + _, + ) => { let substrait_list = list .iter() .map(|x| to_substrait_rex(state, x, schema, col_ref_offset, extensions)) @@ -1027,7 +1030,7 @@ pub fn to_substrait_rex( Ok(substrait_or_list) } } - Expr::ScalarFunction(fun) => { + Expr::ScalarFunction(fun, _) => { let mut arguments: Vec = vec![]; for arg in &fun.args { arguments.push(FunctionArgument { @@ -1052,12 +1055,15 @@ pub fn to_substrait_rex( })), }) } - Expr::Between(Between { - expr, - negated, - low, - high, - }) => { + Expr::Between( + Between { + expr, + negated, + low, + high, + }, + _, + ) => { if *negated { // `expr NOT BETWEEN low AND high` can be translated into (expr < low OR high < expr) let substrait_expr = @@ -1116,21 +1122,24 @@ pub fn to_substrait_rex( )) } } - Expr::Column(col) => { + Expr::Column(col, _) => { let index = schema.index_of_column(col)?; substrait_field_ref(index + col_ref_offset) } - Expr::BinaryExpr(BinaryExpr { left, op, right }) => { + Expr::BinaryExpr(BinaryExpr { left, op, right }, _) => { let l = to_substrait_rex(state, left, schema, col_ref_offset, extensions)?; let r = to_substrait_rex(state, right, schema, col_ref_offset, extensions)?; Ok(make_binary_op_scalar_func(&l, &r, *op, extensions)) } - Expr::Case(Case { - expr, - when_then_expr, - else_expr, - }) => { + Expr::Case( + Case { + expr, + when_then_expr, + else_expr, + }, + _, + ) => { let mut ifs: Vec = vec![]; // Parse base if let Some(e) = expr { @@ -1182,7 +1191,7 @@ pub fn to_substrait_rex( rex_type: Some(RexType::IfThen(Box::new(IfThen { ifs, r#else }))), }) } - Expr::Cast(Cast { expr, data_type }) => { + Expr::Cast(Cast { expr, data_type }, _) => { Ok(Expression { rex_type: Some(RexType::Cast(Box::new( substrait::proto::expression::Cast { @@ -1199,18 +1208,21 @@ pub fn to_substrait_rex( ))), }) } - Expr::Literal(value) => to_substrait_literal_expr(value, extensions), - Expr::Alias(Alias { expr, .. }) => { + Expr::Literal(value, _) => to_substrait_literal_expr(value, extensions), + Expr::Alias(Alias { expr, .. }, _) => { to_substrait_rex(state, expr, schema, col_ref_offset, extensions) } - Expr::WindowFunction(WindowFunction { - fun, - args, - partition_by, - order_by, - window_frame, - null_treatment: _, - }) => { + Expr::WindowFunction( + WindowFunction { + fun, + args, + partition_by, + order_by, + window_frame, + null_treatment: _, + }, + _, + ) => { // function reference let function_anchor = extensions.register_function(fun.to_string()); // arguments @@ -1248,13 +1260,16 @@ pub fn to_substrait_rex( bound_type, )) } - Expr::Like(Like { - negated, - expr, - pattern, - escape_char, - case_insensitive, - }) => make_substrait_like_expr( + Expr::Like( + Like { + negated, + expr, + pattern, + escape_char, + case_insensitive, + }, + _, + ) => make_substrait_like_expr( state, *case_insensitive, *negated, @@ -1265,11 +1280,14 @@ pub fn to_substrait_rex( col_ref_offset, extensions, ), - Expr::InSubquery(InSubquery { - expr, - subquery, - negated, - }) => { + Expr::InSubquery( + InSubquery { + expr, + subquery, + negated, + }, + _, + ) => { let substrait_expr = to_substrait_rex(state, expr, schema, col_ref_offset, extensions)?; @@ -1306,7 +1324,7 @@ pub fn to_substrait_rex( Ok(substrait_subquery) } } - Expr::Not(arg) => to_substrait_unary_scalar_fn( + Expr::Not(arg, _) => to_substrait_unary_scalar_fn( state, "not", arg, @@ -1314,7 +1332,7 @@ pub fn to_substrait_rex( col_ref_offset, extensions, ), - Expr::IsNull(arg) => to_substrait_unary_scalar_fn( + Expr::IsNull(arg, _) => to_substrait_unary_scalar_fn( state, "is_null", arg, @@ -1322,7 +1340,7 @@ pub fn to_substrait_rex( col_ref_offset, extensions, ), - Expr::IsNotNull(arg) => to_substrait_unary_scalar_fn( + Expr::IsNotNull(arg, _) => to_substrait_unary_scalar_fn( state, "is_not_null", arg, @@ -1330,7 +1348,7 @@ pub fn to_substrait_rex( col_ref_offset, extensions, ), - Expr::IsTrue(arg) => to_substrait_unary_scalar_fn( + Expr::IsTrue(arg, _) => to_substrait_unary_scalar_fn( state, "is_true", arg, @@ -1338,7 +1356,7 @@ pub fn to_substrait_rex( col_ref_offset, extensions, ), - Expr::IsFalse(arg) => to_substrait_unary_scalar_fn( + Expr::IsFalse(arg, _) => to_substrait_unary_scalar_fn( state, "is_false", arg, @@ -1346,7 +1364,7 @@ pub fn to_substrait_rex( col_ref_offset, extensions, ), - Expr::IsUnknown(arg) => to_substrait_unary_scalar_fn( + Expr::IsUnknown(arg, _) => to_substrait_unary_scalar_fn( state, "is_unknown", arg, @@ -1354,7 +1372,7 @@ pub fn to_substrait_rex( col_ref_offset, extensions, ), - Expr::IsNotTrue(arg) => to_substrait_unary_scalar_fn( + Expr::IsNotTrue(arg, _) => to_substrait_unary_scalar_fn( state, "is_not_true", arg, @@ -1362,7 +1380,7 @@ pub fn to_substrait_rex( col_ref_offset, extensions, ), - Expr::IsNotFalse(arg) => to_substrait_unary_scalar_fn( + Expr::IsNotFalse(arg, _) => to_substrait_unary_scalar_fn( state, "is_not_false", arg, @@ -1370,7 +1388,7 @@ pub fn to_substrait_rex( col_ref_offset, extensions, ), - Expr::IsNotUnknown(arg) => to_substrait_unary_scalar_fn( + Expr::IsNotUnknown(arg, _) => to_substrait_unary_scalar_fn( state, "is_not_unknown", arg, @@ -1378,7 +1396,7 @@ pub fn to_substrait_rex( col_ref_offset, extensions, ), - Expr::Negative(arg) => to_substrait_unary_scalar_fn( + Expr::Negative(arg, _) => to_substrait_unary_scalar_fn( state, "negate", arg, @@ -2125,7 +2143,7 @@ fn try_to_substrait_field_reference( schema: &DFSchemaRef, ) -> Result { match expr { - Expr::Column(col) => { + Expr::Column(col, _) => { let index = schema.index_of_column(col)?; Ok(FieldReference { reference_type: Some(ReferenceType::DirectReference(ReferenceSegment { @@ -2444,7 +2462,7 @@ mod test { let state = SessionStateBuilder::default().build(); // One expression, empty input schema - let expr = Expr::Literal(ScalarValue::Int32(Some(42))); + let expr = Expr::literal(ScalarValue::Int32(Some(42))); let field = Field::new("out", DataType::Int32, false); let empty_schema = DFSchemaRef::new(DFSchema::empty()); let substrait = @@ -2459,8 +2477,8 @@ mod test { assert_eq!(rt_expr, &expr); // Multiple expressions, with column references - let expr1 = Expr::Column("c0".into()); - let expr2 = Expr::Column("c1".into()); + let expr1 = Expr::column("c0".into()); + let expr2 = Expr::column("c1".into()); let out1 = Field::new("out1", DataType::Int32, true); let out2 = Field::new("out2", DataType::Utf8, true); let input_schema = DFSchemaRef::new(DFSchema::try_from(Schema::new(vec![ @@ -2496,7 +2514,7 @@ mod test { let state = SessionStateBuilder::default().build(); // Not ok if input schema is missing field referenced by expr - let expr = Expr::Column("missing".into()); + let expr = Expr::column("missing".into()); let field = Field::new("out", DataType::Int32, false); let empty_schema = DFSchemaRef::new(DFSchema::empty()); diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index d03ab5182028a..115010e50d77b 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -898,7 +898,7 @@ async fn roundtrip_values() -> Result<()> { async fn roundtrip_values_no_columns() -> Result<()> { let ctx = create_context().await?; // "VALUES ()" is not yet supported by the SQL parser, so we construct the plan manually - let plan = LogicalPlan::Values(Values { + let plan = LogicalPlan::values(Values { values: vec![vec![], vec![]], // two rows, no columns schema: DFSchemaRef::new(DFSchema::empty()), }); @@ -971,7 +971,7 @@ async fn new_test_grammar() -> Result<()> { async fn extension_logical_plan() -> Result<()> { let ctx = create_context().await?; let validation_bytes = "MockUserDefinedLogicalPlan".as_bytes().to_vec(); - let ext_plan = LogicalPlan::Extension(Extension { + let ext_plan = LogicalPlan::extension(Extension { node: Arc::new(MockUserDefinedLogicalPlan { validation_bytes, inputs: vec![], @@ -1076,7 +1076,7 @@ async fn roundtrip_window_udf() -> Result<()> { async fn roundtrip_repartition_roundrobin() -> Result<()> { let ctx = create_context().await?; let scan_plan = ctx.sql("SELECT * FROM data").await?.into_optimized_plan()?; - let plan = LogicalPlan::Repartition(Repartition { + let plan = LogicalPlan::repartition(Repartition { input: Arc::new(scan_plan), partitioning_scheme: Partitioning::RoundRobinBatch(8), }); @@ -1093,7 +1093,7 @@ async fn roundtrip_repartition_roundrobin() -> Result<()> { async fn roundtrip_repartition_hash() -> Result<()> { let ctx = create_context().await?; let scan_plan = ctx.sql("SELECT * FROM data").await?.into_optimized_plan()?; - let plan = LogicalPlan::Repartition(Repartition { + let plan = LogicalPlan::repartition(Repartition { input: Arc::new(scan_plan), partitioning_scheme: Partitioning::Hash(vec![col("data.a")], 8), }); diff --git a/docs/source/library-user-guide/building-logical-plans.md b/docs/source/library-user-guide/building-logical-plans.md index 556deb02e9800..e57faee5de56b 100644 --- a/docs/source/library-user-guide/building-logical-plans.md +++ b/docs/source/library-user-guide/building-logical-plans.md @@ -55,7 +55,7 @@ fn main() -> Result<(), DataFusionError> { let projection = None; // optional projection let filters = vec![]; // optional filters to push down let fetch = None; // optional LIMIT - let table_scan = LogicalPlan::TableScan(TableScan::try_new( + let table_scan = LogicalPlan::table_scan(TableScan::try_new( "person", Arc::new(table_source), projection, @@ -66,7 +66,7 @@ fn main() -> Result<(), DataFusionError> { // create a Filter plan that evaluates `id > 500` that wraps the TableScan let filter_expr = col("id").gt(lit(500)); - let plan = LogicalPlan::Filter(Filter::try_new(filter_expr, Arc::new(table_scan)) ? ); + let plan = LogicalPlan::filter(Filter::try_new(filter_expr, Arc::new(table_scan)) ? ); // print the plan println!("{}", plan.display_indent_schema());