diff --git a/src/ast/query.rs b/src/ast/query.rs index e7020ae23..e94306358 100644 --- a/src/ast/query.rs +++ b/src/ast/query.rs @@ -328,6 +328,10 @@ pub struct Select { pub value_table_mode: Option, /// STARTING WITH .. CONNECT BY pub connect_by: Option, + /// Hive syntax: `SELECT ... GROUP BY .. GROUPING SETS` + /// + /// [Hive](https://cwiki.apache.org/confluence/pages/viewpage.action?pageId=30151323#EnhancedAggregation,Cube,GroupingandRollup-GROUPINGSETSclause) + pub grouping_sets: Option, } impl fmt::Display for Select { @@ -382,6 +386,23 @@ impl fmt::Display for Select { } } } + + if let Some(ref grouping_sets) = self.grouping_sets { + match grouping_sets { + Expr::GroupingSets(sets) => { + write!(f, " GROUPING SETS (")?; + let mut sep = ""; + for set in sets { + write!(f, "{sep}")?; + sep = ", "; + write!(f, "({})", display_comma_separated(set))?; + } + write!(f, ")")? + } + _ => unreachable!(), + } + } + if !self.cluster_by.is_empty() { write!( f, diff --git a/src/ast/spans.rs b/src/ast/spans.rs index 183bebf8c..6882f98f2 100644 --- a/src/ast/spans.rs +++ b/src/ast/spans.rs @@ -2030,6 +2030,7 @@ impl Spanned for Select { prewhere, selection, group_by, + grouping_sets, cluster_by, distribute_by, sort_by, @@ -2051,6 +2052,7 @@ impl Spanned for Select { .chain(prewhere.iter().map(|item| item.span())) .chain(selection.iter().map(|item| item.span())) .chain(core::iter::once(group_by.span())) + .chain(grouping_sets.iter().map(|item| item.span())) .chain(cluster_by.iter().map(|item| item.span())) .chain(distribute_by.iter().map(|item| item.span())) .chain(sort_by.iter().map(|item| item.span())) diff --git a/src/dialect/clickhouse.rs b/src/dialect/clickhouse.rs index 884dfcbcb..215bb4f33 100644 --- a/src/dialect/clickhouse.rs +++ b/src/dialect/clickhouse.rs @@ -66,4 +66,9 @@ impl Dialect for ClickHouseDialect { fn supports_dictionary_syntax(&self) -> bool { true } + + /// See + fn supports_group_by_with_modifier(&self) -> bool { + true + } } diff --git a/src/dialect/generic.rs b/src/dialect/generic.rs index d696861b5..f3f8390c5 100644 --- a/src/dialect/generic.rs +++ b/src/dialect/generic.rs @@ -48,6 +48,14 @@ impl Dialect for GenericDialect { true } + fn supports_group_by_with_modifier(&self) -> bool { + true + } + + fn supports_select_grouping_sets(&self) -> bool { + true + } + fn supports_connect_by(&self) -> bool { true } diff --git a/src/dialect/hive.rs b/src/dialect/hive.rs index 80f44cf7c..bbfd01782 100644 --- a/src/dialect/hive.rs +++ b/src/dialect/hive.rs @@ -52,18 +52,28 @@ impl Dialect for HiveDialect { true } - /// See Hive + /// See fn supports_bang_not_operator(&self) -> bool { true } - /// See Hive + /// See fn supports_load_data(&self) -> bool { true } - /// See Hive + /// See fn supports_table_sample_before_alias(&self) -> bool { true } + + /// See + fn supports_group_by_with_modifier(&self) -> bool { + true + } + + /// See + fn supports_select_grouping_sets(&self) -> bool { + true + } } diff --git a/src/dialect/mod.rs b/src/dialect/mod.rs index c66982d1f..a077e00a0 100644 --- a/src/dialect/mod.rs +++ b/src/dialect/mod.rs @@ -245,6 +245,16 @@ pub trait Dialect: Debug + Any { false } + /// Returns true if the dialects supports `with rollup/cube/all` expressions. + fn supports_group_by_with_modifier(&self) -> bool { + false + } + + /// Returns true if the dialects supports `select .. grouping sets` expressions. + fn supports_select_grouping_sets(&self) -> bool { + false + } + /// Returns true if the dialect supports CONNECT BY. fn supports_connect_by(&self) -> bool { false diff --git a/src/parser/mod.rs b/src/parser/mod.rs index 3cf3c585e..21c77bc20 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -8886,7 +8886,7 @@ impl<'a> Parser<'a> { }; let mut modifiers = vec![]; - if dialect_of!(self is ClickHouseDialect | GenericDialect) { + if self.dialect.supports_group_by_with_modifier() { loop { if !self.parse_keyword(Keyword::WITH) { break; @@ -10127,6 +10127,17 @@ impl<'a> Parser<'a> { .parse_optional_group_by()? .unwrap_or_else(|| GroupByExpr::Expressions(vec![], vec![])); + let grouping_sets = if self.dialect.supports_select_grouping_sets() + && self.parse_keywords(&[Keyword::GROUPING, Keyword::SETS]) + { + self.expect_token(&Token::LParen)?; + let result = self.parse_comma_separated(|p| p.parse_tuple(true, true))?; + self.expect_token(&Token::RParen)?; + Some(Expr::GroupingSets(result)) + } else { + None + }; + let cluster_by = if self.parse_keywords(&[Keyword::CLUSTER, Keyword::BY]) { self.parse_comma_separated(Parser::parse_expr)? } else { @@ -10198,6 +10209,7 @@ impl<'a> Parser<'a> { prewhere, selection, group_by, + grouping_sets, cluster_by, distribute_by, sort_by, diff --git a/tests/sqlparser_clickhouse.rs b/tests/sqlparser_clickhouse.rs index fed4308fc..6d6e43aa1 100644 --- a/tests/sqlparser_clickhouse.rs +++ b/tests/sqlparser_clickhouse.rs @@ -92,6 +92,7 @@ fn parse_map_access_expr() { }), }), group_by: GroupByExpr::Expressions(vec![], vec![]), + grouping_sets: None, cluster_by: vec![], distribute_by: vec![], sort_by: vec![], @@ -1065,61 +1066,6 @@ fn parse_create_materialized_view() { clickhouse_and_generic().verified_stmt(sql); } -#[test] -fn parse_group_by_with_modifier() { - let clauses = ["x", "a, b", "ALL"]; - let modifiers = [ - "WITH ROLLUP", - "WITH CUBE", - "WITH TOTALS", - "WITH ROLLUP WITH CUBE", - ]; - let expected_modifiers = [ - vec![GroupByWithModifier::Rollup], - vec![GroupByWithModifier::Cube], - vec![GroupByWithModifier::Totals], - vec![GroupByWithModifier::Rollup, GroupByWithModifier::Cube], - ]; - for clause in &clauses { - for (modifier, expected_modifier) in modifiers.iter().zip(expected_modifiers.iter()) { - let sql = format!("SELECT * FROM t GROUP BY {clause} {modifier}"); - match clickhouse_and_generic().verified_stmt(&sql) { - Statement::Query(query) => { - let group_by = &query.body.as_select().unwrap().group_by; - if clause == &"ALL" { - assert_eq!(group_by, &GroupByExpr::All(expected_modifier.to_vec())); - } else { - assert_eq!( - group_by, - &GroupByExpr::Expressions( - clause - .split(", ") - .map(|c| Identifier(Ident::new(c))) - .collect(), - expected_modifier.to_vec() - ) - ); - } - } - _ => unreachable!(), - } - } - } - - // invalid cases - let invalid_cases = [ - "SELECT * FROM t GROUP BY x WITH", - "SELECT * FROM t GROUP BY x WITH ROLLUP CUBE", - "SELECT * FROM t GROUP BY x WITH WITH ROLLUP", - "SELECT * FROM t GROUP BY WITH ROLLUP", - ]; - for sql in invalid_cases { - clickhouse_and_generic() - .parse_sql_statements(sql) - .expect_err("Expected: one of ROLLUP or CUBE or TOTALS, found: WITH"); - } -} - #[test] fn parse_select_order_by_with_fill_interpolate() { let sql = "SELECT id, fname, lname FROM customer WHERE id < 5 \ diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index ab69b48ae..63d4790f2 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -408,6 +408,7 @@ fn parse_update_set_from() { vec![Expr::Identifier(Ident::new("id"))], vec![] ), + grouping_sets: None, cluster_by: vec![], distribute_by: vec![], sort_by: vec![], @@ -2314,6 +2315,95 @@ fn parse_select_group_by_all() { ); } +#[test] +fn parse_group_by_with_modifier() { + let clauses = ["x", "a, b", "ALL"]; + let modifiers = [ + "WITH ROLLUP", + "WITH CUBE", + "WITH TOTALS", + "WITH ROLLUP WITH CUBE", + ]; + let expected_modifiers = [ + vec![GroupByWithModifier::Rollup], + vec![GroupByWithModifier::Cube], + vec![GroupByWithModifier::Totals], + vec![GroupByWithModifier::Rollup, GroupByWithModifier::Cube], + ]; + let dialects = all_dialects_where(|d| d.supports_group_by_with_modifier()); + + for clause in &clauses { + for (modifier, expected_modifier) in modifiers.iter().zip(expected_modifiers.iter()) { + let sql = format!("SELECT * FROM t GROUP BY {clause} {modifier}"); + match dialects.verified_stmt(&sql) { + Statement::Query(query) => { + let group_by = &query.body.as_select().unwrap().group_by; + if clause == &"ALL" { + assert_eq!(group_by, &GroupByExpr::All(expected_modifier.to_vec())); + } else { + assert_eq!( + group_by, + &GroupByExpr::Expressions( + clause + .split(", ") + .map(|c| Identifier(Ident::new(c))) + .collect(), + expected_modifier.to_vec() + ) + ); + } + } + _ => unreachable!(), + } + } + } + + // invalid cases + let invalid_cases = [ + "SELECT * FROM t GROUP BY x WITH", + "SELECT * FROM t GROUP BY x WITH ROLLUP CUBE", + "SELECT * FROM t GROUP BY x WITH WITH ROLLUP", + "SELECT * FROM t GROUP BY WITH ROLLUP", + ]; + for sql in invalid_cases { + dialects + .parse_sql_statements(sql) + .expect_err("Expected: one of ROLLUP or CUBE or TOTALS, found: WITH"); + } +} + +#[test] +fn parse_select_grouping_sets() { + let dialects = all_dialects_where(|d| d.supports_select_grouping_sets()); + + let sql = "SELECT a, b, SUM(c) FROM tab1 GROUP BY a, b GROUPING SETS ((a, b), (a), (b), ())"; + match dialects.verified_stmt(sql) { + Statement::Query(query) => { + let grouping_sets = &query.body.as_select().unwrap().grouping_sets; + assert_eq!( + grouping_sets, + &Some(Expr::GroupingSets(vec![ + vec![ + Expr::Identifier(Ident::new("a")), + Expr::Identifier(Ident::new("b")) + ], + vec![Expr::Identifier(Ident::new("a")),], + vec![Expr::Identifier(Ident::new("b"))], + vec![] + ])) + ); + } + _ => unreachable!(), + } + + let dialects = all_dialects_where(|d| !d.supports_select_grouping_sets()); + + assert_eq!( + dialects.parse_sql_statements(sql).unwrap_err(), + ParserError::ParserError("Expected: end of statement, found: GROUPING".to_string()) + ); +} + #[test] fn parse_select_having() { let sql = "SELECT foo FROM bar GROUP BY foo HAVING COUNT(*) > 1"; @@ -5105,6 +5195,7 @@ fn test_parse_named_window() { prewhere: None, selection: None, group_by: GroupByExpr::Expressions(vec![], vec![]), + grouping_sets: None, cluster_by: vec![], distribute_by: vec![], sort_by: vec![], @@ -5771,6 +5862,7 @@ fn parse_interval_and_or_xor() { }), }), group_by: GroupByExpr::Expressions(vec![], vec![]), + grouping_sets: None, cluster_by: vec![], distribute_by: vec![], sort_by: vec![], @@ -7829,6 +7921,7 @@ fn lateral_function() { prewhere: None, selection: None, group_by: GroupByExpr::Expressions(vec![], vec![]), + grouping_sets: None, cluster_by: vec![], distribute_by: vec![], sort_by: vec![], @@ -8717,6 +8810,7 @@ fn parse_merge() { prewhere: None, selection: None, group_by: GroupByExpr::Expressions(vec![], vec![]), + grouping_sets: None, cluster_by: vec![], distribute_by: vec![], sort_by: vec![], @@ -10357,6 +10451,7 @@ fn parse_unload() { prewhere: None, selection: None, group_by: GroupByExpr::Expressions(vec![], vec![]), + grouping_sets: None, cluster_by: vec![], distribute_by: vec![], sort_by: vec![], @@ -10551,6 +10646,7 @@ fn parse_connect_by() { prewhere: None, selection: None, group_by: GroupByExpr::Expressions(vec![], vec![]), + grouping_sets: None, cluster_by: vec![], distribute_by: vec![], sort_by: vec![], @@ -10635,6 +10731,7 @@ fn parse_connect_by() { right: Box::new(Expr::Value(number("42"))), }), group_by: GroupByExpr::Expressions(vec![], vec![]), + grouping_sets: None, cluster_by: vec![], distribute_by: vec![], sort_by: vec![], @@ -11511,6 +11608,7 @@ fn test_extract_seconds_ok() { prewhere: None, selection: None, group_by: GroupByExpr::Expressions(vec![], vec![]), + grouping_sets: None, cluster_by: vec![], distribute_by: vec![], sort_by: vec![], diff --git a/tests/sqlparser_duckdb.rs b/tests/sqlparser_duckdb.rs index db4ffb6f6..a4edcb547 100644 --- a/tests/sqlparser_duckdb.rs +++ b/tests/sqlparser_duckdb.rs @@ -279,6 +279,7 @@ fn test_select_union_by_name() { prewhere: None, selection: None, group_by: GroupByExpr::Expressions(vec![], vec![]), + grouping_sets: None, cluster_by: vec![], distribute_by: vec![], sort_by: vec![], @@ -308,6 +309,7 @@ fn test_select_union_by_name() { prewhere: None, selection: None, group_by: GroupByExpr::Expressions(vec![], vec![]), + grouping_sets: None, cluster_by: vec![], distribute_by: vec![], sort_by: vec![], diff --git a/tests/sqlparser_mssql.rs b/tests/sqlparser_mssql.rs index 567cd5382..9ea3a799c 100644 --- a/tests/sqlparser_mssql.rs +++ b/tests/sqlparser_mssql.rs @@ -127,6 +127,7 @@ fn parse_create_procedure() { prewhere: None, selection: None, group_by: GroupByExpr::Expressions(vec![], vec![]), + grouping_sets: None, cluster_by: vec![], distribute_by: vec![], sort_by: vec![], @@ -1097,6 +1098,7 @@ fn parse_substring_in_select() { prewhere: None, selection: None, group_by: GroupByExpr::Expressions(vec![], vec![]), + grouping_sets: None, cluster_by: vec![], distribute_by: vec![], sort_by: vec![], @@ -1234,6 +1236,7 @@ fn parse_mssql_declare() { prewhere: None, selection: None, group_by: GroupByExpr::Expressions(vec![], vec![]), + grouping_sets: None, cluster_by: vec![], distribute_by: vec![], sort_by: vec![], diff --git a/tests/sqlparser_mysql.rs b/tests/sqlparser_mysql.rs index dcf3f57fe..15de4007e 100644 --- a/tests/sqlparser_mysql.rs +++ b/tests/sqlparser_mysql.rs @@ -1099,6 +1099,7 @@ fn parse_escaped_quote_identifiers_with_escape() { prewhere: None, selection: None, group_by: GroupByExpr::Expressions(vec![], vec![]), + grouping_sets: None, cluster_by: vec![], distribute_by: vec![], sort_by: vec![], @@ -1152,6 +1153,7 @@ fn parse_escaped_quote_identifiers_with_no_escape() { prewhere: None, selection: None, group_by: GroupByExpr::Expressions(vec![], vec![]), + grouping_sets: None, cluster_by: vec![], distribute_by: vec![], sort_by: vec![], @@ -1199,6 +1201,7 @@ fn parse_escaped_backticks_with_escape() { prewhere: None, selection: None, group_by: GroupByExpr::Expressions(vec![], vec![]), + grouping_sets: None, cluster_by: vec![], distribute_by: vec![], sort_by: vec![], @@ -1250,6 +1253,7 @@ fn parse_escaped_backticks_with_no_escape() { prewhere: None, selection: None, group_by: GroupByExpr::Expressions(vec![], vec![]), + grouping_sets: None, cluster_by: vec![], distribute_by: vec![], sort_by: vec![], @@ -1919,6 +1923,7 @@ fn parse_select_with_numeric_prefix_column_name() { prewhere: None, selection: None, group_by: GroupByExpr::Expressions(vec![], vec![]), + grouping_sets: None, cluster_by: vec![], distribute_by: vec![], sort_by: vec![], @@ -1971,6 +1976,7 @@ fn parse_select_with_concatenation_of_exp_number_and_numeric_prefix_column() { prewhere: None, selection: None, group_by: GroupByExpr::Expressions(vec![], vec![]), + grouping_sets: None, cluster_by: vec![], distribute_by: vec![], sort_by: vec![], @@ -2489,6 +2495,7 @@ fn parse_substring_in_select() { prewhere: None, selection: None, group_by: GroupByExpr::Expressions(vec![], vec![]), + grouping_sets: None, cluster_by: vec![], distribute_by: vec![], sort_by: vec![], @@ -2784,6 +2791,7 @@ fn parse_hex_string_introducer() { prewhere: None, selection: None, group_by: GroupByExpr::Expressions(vec![], vec![]), + grouping_sets: None, cluster_by: vec![], distribute_by: vec![], sort_by: vec![], diff --git a/tests/sqlparser_postgres.rs b/tests/sqlparser_postgres.rs index 864fb5eb3..eea334501 100644 --- a/tests/sqlparser_postgres.rs +++ b/tests/sqlparser_postgres.rs @@ -1306,6 +1306,7 @@ fn parse_copy_to() { prewhere: None, selection: None, group_by: GroupByExpr::Expressions(vec![], vec![]), + grouping_sets: None, having: None, named_window: vec![], window_before_qualify: false, @@ -2644,6 +2645,7 @@ fn parse_array_subquery_expr() { prewhere: None, selection: None, group_by: GroupByExpr::Expressions(vec![], vec![]), + grouping_sets: None, cluster_by: vec![], distribute_by: vec![], sort_by: vec![], @@ -2666,6 +2668,7 @@ fn parse_array_subquery_expr() { prewhere: None, selection: None, group_by: GroupByExpr::Expressions(vec![], vec![]), + grouping_sets: None, cluster_by: vec![], distribute_by: vec![], sort_by: vec![],