Skip to content

Commit

Permalink
Add support for Hive's SELECT ... GROUP BY ... GROUPING SETS syntax
Browse files Browse the repository at this point in the history
  • Loading branch information
wugeer committed Jan 9, 2025
1 parent 687ce2d commit 9779f2e
Show file tree
Hide file tree
Showing 13 changed files with 186 additions and 59 deletions.
21 changes: 21 additions & 0 deletions src/ast/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,10 @@ pub struct Select {
pub value_table_mode: Option<ValueTableMode>,
/// STARTING WITH .. CONNECT BY
pub connect_by: Option<ConnectBy>,
/// Hive syntax: `SELECT ... GROUP BY .. GROUPING SETS`
///
/// [Hive](https://cwiki.apache.org/confluence/pages/viewpage.action?pageId=30151323#EnhancedAggregation,Cube,GroupingandRollup-GROUPINGSETSclause)
pub grouping_sets: Option<Expr>,
}

impl fmt::Display for Select {
Expand Down Expand Up @@ -380,6 +384,23 @@ impl fmt::Display for Select {
}
}
}

// Hive-style trailing `GROUPING SETS ((..), ..)` clause, written right after
// the GROUP BY output above.
if let Some(grouping_sets) = &self.grouping_sets {
    match grouping_sets {
        Expr::GroupingSets(sets) => {
            write!(f, " GROUPING SETS (")?;
            let mut sep = "";
            for set in sets {
                // Each set is parenthesized; an empty set renders as `()`.
                write!(f, "{sep}({})", display_comma_separated(set))?;
                sep = ", ";
            }
            write!(f, ")")?;
        }
        // `grouping_sets` is a public field, so callers can store any `Expr`
        // here. Formatting must not panic: fall back to the expression's own
        // `Display` output instead of `unreachable!()`.
        other => write!(f, " {other}")?,
    }
}

if !self.cluster_by.is_empty() {
write!(
f,
Expand Down
2 changes: 2 additions & 0 deletions src/ast/spans.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2025,6 +2025,7 @@ impl Spanned for Select {
prewhere,
selection,
group_by,
grouping_sets,
cluster_by,
distribute_by,
sort_by,
Expand All @@ -2046,6 +2047,7 @@ impl Spanned for Select {
.chain(prewhere.iter().map(|item| item.span()))
.chain(selection.iter().map(|item| item.span()))
.chain(core::iter::once(group_by.span()))
.chain(grouping_sets.iter().map(|item| item.span()))
.chain(cluster_by.iter().map(|item| item.span()))
.chain(distribute_by.iter().map(|item| item.span()))
.chain(sort_by.iter().map(|item| item.span()))
Expand Down
5 changes: 5 additions & 0 deletions src/dialect/clickhouse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,4 +50,9 @@ impl Dialect for ClickHouseDialect {
fn supports_limit_comma(&self) -> bool {
true
}

/// Always `true` for ClickHouse: the parser consults this flag before
/// accepting `WITH <modifier>` after a `GROUP BY` clause.
///
/// See <https://clickhouse.com/docs/en/sql-reference/statements/select/group-by#rollup-modifier>
fn supports_group_by_with_modifier(&self) -> bool {
true
}
}
8 changes: 8 additions & 0 deletions src/dialect/generic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,14 @@ impl Dialect for GenericDialect {
true
}

/// Always `true` for the generic dialect: `WITH <modifier>` after `GROUP BY`
/// is accepted by the parser.
fn supports_group_by_with_modifier(&self) -> bool {
true
}

/// Always `true` for the generic dialect: a trailing `GROUPING SETS (...)`
/// clause after `GROUP BY` is accepted by the parser.
fn supports_select_grouping_sets(&self) -> bool {
true
}

/// Always `true` for the generic dialect: it reports support for `CONNECT BY`.
fn supports_connect_by(&self) -> bool {
true
}
Expand Down
16 changes: 13 additions & 3 deletions src/dialect/hive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,18 +52,28 @@ impl Dialect for HiveDialect {
true
}

/// Always `true` for the Hive dialect.
///
/// See <https://cwiki.apache.org/confluence/pages/viewpage.action?pageId=27362061#Tutorial-BuiltInOperators>
fn supports_bang_not_operator(&self) -> bool {
true
}

/// Always `true` for the Hive dialect.
///
/// See <https://cwiki.apache.org/confluence/pages/viewpage.action?pageId=27362036#LanguageManualDML-Loadingfilesintotables>
fn supports_load_data(&self) -> bool {
true
}

/// Always `true` for the Hive dialect.
///
/// See <https://cwiki.apache.org/confluence/display/hive/languagemanual+sampling>
fn supports_table_sample_before_alias(&self) -> bool {
true
}

/// Always `true` for Hive: the parser consults this flag before accepting
/// `WITH <modifier>` after a `GROUP BY` clause.
///
/// See <https://cwiki.apache.org/confluence/pages/viewpage.action?pageId=30151323#EnhancedAggregation,Cube,GroupingandRollup-CubesandRollupsr>
fn supports_group_by_with_modifier(&self) -> bool {
true
}

/// Always `true` for Hive: the parser consults this flag before accepting a
/// trailing `GROUPING SETS (...)` clause after `GROUP BY`.
///
/// See <https://cwiki.apache.org/confluence/pages/viewpage.action?pageId=30151323#EnhancedAggregation,Cube,GroupingandRollup-GROUPINGSETSclause>
fn supports_select_grouping_sets(&self) -> bool {
true
}
}
10 changes: 10 additions & 0 deletions src/dialect/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,16 @@ pub trait Dialect: Debug + Any {
false
}

/// Returns true if the dialect supports `GROUP BY ... WITH ROLLUP/CUBE/TOTALS` modifiers.
fn supports_group_by_with_modifier(&self) -> bool {
false
}

/// Returns true if the dialect supports the `SELECT ... GROUP BY ... GROUPING SETS (...)` clause.
fn supports_select_grouping_sets(&self) -> bool {
false
}

/// Returns true if the dialect supports CONNECT BY.
fn supports_connect_by(&self) -> bool {
false
Expand Down
14 changes: 13 additions & 1 deletion src/parser/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8821,7 +8821,7 @@ impl<'a> Parser<'a> {
};

let mut modifiers = vec![];
if dialect_of!(self is ClickHouseDialect | GenericDialect) {
if self.dialect.supports_group_by_with_modifier() {
loop {
if !self.parse_keyword(Keyword::WITH) {
break;
Expand Down Expand Up @@ -10020,6 +10020,17 @@ impl<'a> Parser<'a> {
.parse_optional_group_by()?
.unwrap_or_else(|| GroupByExpr::Expressions(vec![], vec![]));

// Optional Hive-style `GROUPING SETS ((..), ..)` clause following GROUP BY.
// The dialect check is evaluated first so the keywords are only consumed
// when the dialect actually supports the clause.
let has_grouping_sets_clause = self.dialect.supports_select_grouping_sets()
    && self.parse_keywords(&[Keyword::GROUPING, Keyword::SETS]);
let grouping_sets = if !has_grouping_sets_clause {
    None
} else {
    self.expect_token(&Token::LParen)?;
    let sets = self.parse_comma_separated(|parser| parser.parse_tuple(true, true))?;
    self.expect_token(&Token::RParen)?;
    Some(Expr::GroupingSets(sets))
};

let cluster_by = if self.parse_keywords(&[Keyword::CLUSTER, Keyword::BY]) {
self.parse_comma_separated(Parser::parse_expr)?
} else {
Expand Down Expand Up @@ -10091,6 +10102,7 @@ impl<'a> Parser<'a> {
prewhere,
selection,
group_by,
grouping_sets,
cluster_by,
distribute_by,
sort_by,
Expand Down
56 changes: 1 addition & 55 deletions tests/sqlparser_clickhouse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ fn parse_map_access_expr() {
}),
}),
group_by: GroupByExpr::Expressions(vec![], vec![]),
grouping_sets: None,
cluster_by: vec![],
distribute_by: vec![],
sort_by: vec![],
Expand Down Expand Up @@ -1059,61 +1060,6 @@ fn parse_create_materialized_view() {
clickhouse_and_generic().verified_stmt(sql);
}

/// Round-trips `GROUP BY <clause> WITH <modifier>` through the ClickHouse and
/// generic dialects and checks the parsed modifiers, then checks that
/// malformed modifier lists are rejected.
#[test]
fn parse_group_by_with_modifier() {
    let clauses = ["x", "a, b", "ALL"];
    let modifiers = [
        "WITH ROLLUP",
        "WITH CUBE",
        "WITH TOTALS",
        "WITH ROLLUP WITH CUBE",
    ];
    // Kept in the same order as `modifiers`; the two arrays are zipped below.
    let expected_modifiers = [
        vec![GroupByWithModifier::Rollup],
        vec![GroupByWithModifier::Cube],
        vec![GroupByWithModifier::Totals],
        vec![GroupByWithModifier::Rollup, GroupByWithModifier::Cube],
    ];
    for clause in &clauses {
        for (modifier, expected_modifier) in modifiers.iter().zip(expected_modifiers.iter()) {
            let sql = format!("SELECT * FROM t GROUP BY {clause} {modifier}");
            match clickhouse_and_generic().verified_stmt(&sql) {
                Statement::Query(query) => {
                    let group_by = &query.body.as_select().unwrap().group_by;
                    if clause == &"ALL" {
                        assert_eq!(group_by, &GroupByExpr::All(expected_modifier.to_vec()));
                    } else {
                        assert_eq!(
                            group_by,
                            &GroupByExpr::Expressions(
                                clause
                                    .split(", ")
                                    .map(|c| Identifier(Ident::new(c)))
                                    .collect(),
                                expected_modifier.to_vec()
                            )
                        );
                    }
                }
                _ => unreachable!(),
            }
        }
    }

    // invalid cases
    let invalid_cases = [
        "SELECT * FROM t GROUP BY x WITH",
        "SELECT * FROM t GROUP BY x WITH ROLLUP CUBE",
        "SELECT * FROM t GROUP BY x WITH WITH ROLLUP",
        "SELECT * FROM t GROUP BY WITH ROLLUP",
    ];
    for sql in invalid_cases {
        // The argument of `expect_err` is only the panic message used when
        // parsing unexpectedly succeeds; it never checked the error text.
        // State the real expectation so a failure names the offending SQL.
        assert!(
            clickhouse_and_generic().parse_sql_statements(sql).is_err(),
            "expected a parse error for: {sql}"
        );
    }
}

#[test]
fn parse_select_order_by_with_fill_interpolate() {
let sql = "SELECT id, fname, lname FROM customer WHERE id < 5 \
Expand Down
97 changes: 97 additions & 0 deletions tests/sqlparser_common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,7 @@ fn parse_update_set_from() {
vec![Expr::Identifier(Ident::new("id"))],
vec![]
),
grouping_sets: None,
cluster_by: vec![],
distribute_by: vec![],
sort_by: vec![],
Expand Down Expand Up @@ -2305,6 +2306,95 @@ fn parse_select_group_by_all() {
);
}

/// Round-trips `GROUP BY <clause> WITH <modifier>` on every dialect that
/// reports `supports_group_by_with_modifier()` and checks the parsed
/// modifiers, then checks that malformed modifier lists are rejected.
#[test]
fn parse_group_by_with_modifier() {
    let clauses = ["x", "a, b", "ALL"];
    // Each SQL suffix is paired with the modifiers it must parse into, so the
    // input and its expectation cannot drift apart.
    let modifiers = [
        ("WITH ROLLUP", vec![GroupByWithModifier::Rollup]),
        ("WITH CUBE", vec![GroupByWithModifier::Cube]),
        ("WITH TOTALS", vec![GroupByWithModifier::Totals]),
        (
            "WITH ROLLUP WITH CUBE",
            vec![GroupByWithModifier::Rollup, GroupByWithModifier::Cube],
        ),
    ];
    let dialects = all_dialects_where(|d| d.supports_group_by_with_modifier());

    for clause in clauses {
        for (modifier, expected_modifier) in &modifiers {
            let sql = format!("SELECT * FROM t GROUP BY {clause} {modifier}");
            match dialects.verified_stmt(&sql) {
                Statement::Query(query) => {
                    let group_by = &query.body.as_select().unwrap().group_by;
                    let expected = if clause == "ALL" {
                        GroupByExpr::All(expected_modifier.clone())
                    } else {
                        GroupByExpr::Expressions(
                            clause
                                .split(", ")
                                .map(|c| Identifier(Ident::new(c)))
                                .collect(),
                            expected_modifier.clone(),
                        )
                    };
                    assert_eq!(group_by, &expected);
                }
                _ => unreachable!(),
            }
        }
    }

    // invalid cases
    let invalid_cases = [
        "SELECT * FROM t GROUP BY x WITH",
        "SELECT * FROM t GROUP BY x WITH ROLLUP CUBE",
        "SELECT * FROM t GROUP BY x WITH WITH ROLLUP",
        "SELECT * FROM t GROUP BY WITH ROLLUP",
    ];
    for sql in invalid_cases {
        // The argument of `expect_err` is only the panic message used when
        // parsing unexpectedly succeeds; it never checked the error text.
        // State the real expectation so a failure names the offending SQL.
        assert!(
            dialects.parse_sql_statements(sql).is_err(),
            "expected a parse error for: {sql}"
        );
    }
}

/// Checks that `GROUP BY .. GROUPING SETS (..)` parses into
/// `Select::grouping_sets` on dialects reporting
/// `supports_select_grouping_sets()`, and is rejected on all other dialects.
#[test]
fn parse_select_grouping_sets() {
    let sql = "SELECT a, b, SUM(c) FROM tab1 GROUP BY a, b GROUPING SETS ((a, b), (a), (b), ())";
    let ident = |name: &str| Expr::Identifier(Ident::new(name));

    let supported = all_dialects_where(|d| d.supports_select_grouping_sets());
    // `sql` is already a `&str`; pass it directly instead of borrowing again.
    match supported.verified_stmt(sql) {
        Statement::Query(query) => {
            let grouping_sets = &query.body.as_select().unwrap().grouping_sets;
            assert_eq!(
                grouping_sets,
                &Some(Expr::GroupingSets(vec![
                    vec![ident("a"), ident("b")],
                    vec![ident("a")],
                    vec![ident("b")],
                    // The empty set `()` parses to an empty list.
                    vec![],
                ]))
            );
        }
        _ => unreachable!(),
    }

    // Without dialect support the clause is not consumed, so the statement
    // ends at GROUP BY and the leftover `GROUPING` token is an error.
    let unsupported = all_dialects_where(|d| !d.supports_select_grouping_sets());
    assert_eq!(
        unsupported.parse_sql_statements(sql).unwrap_err(),
        ParserError::ParserError("Expected: end of statement, found: GROUPING".to_string())
    );
}

#[test]
fn parse_select_having() {
let sql = "SELECT foo FROM bar GROUP BY foo HAVING COUNT(*) > 1";
Expand Down Expand Up @@ -5098,6 +5188,7 @@ fn test_parse_named_window() {
prewhere: None,
selection: None,
group_by: GroupByExpr::Expressions(vec![], vec![]),
grouping_sets: None,
cluster_by: vec![],
distribute_by: vec![],
sort_by: vec![],
Expand Down Expand Up @@ -5764,6 +5855,7 @@ fn parse_interval_and_or_xor() {
}),
}),
group_by: GroupByExpr::Expressions(vec![], vec![]),
grouping_sets: None,
cluster_by: vec![],
distribute_by: vec![],
sort_by: vec![],
Expand Down Expand Up @@ -7805,6 +7897,7 @@ fn lateral_function() {
prewhere: None,
selection: None,
group_by: GroupByExpr::Expressions(vec![], vec![]),
grouping_sets: None,
cluster_by: vec![],
distribute_by: vec![],
sort_by: vec![],
Expand Down Expand Up @@ -8667,6 +8760,7 @@ fn parse_merge() {
prewhere: None,
selection: None,
group_by: GroupByExpr::Expressions(vec![], vec![]),
grouping_sets: None,
cluster_by: vec![],
distribute_by: vec![],
sort_by: vec![],
Expand Down Expand Up @@ -10307,6 +10401,7 @@ fn parse_unload() {
prewhere: None,
selection: None,
group_by: GroupByExpr::Expressions(vec![], vec![]),
grouping_sets: None,
cluster_by: vec![],
distribute_by: vec![],
sort_by: vec![],
Expand Down Expand Up @@ -10501,6 +10596,7 @@ fn parse_connect_by() {
prewhere: None,
selection: None,
group_by: GroupByExpr::Expressions(vec![], vec![]),
grouping_sets: None,
cluster_by: vec![],
distribute_by: vec![],
sort_by: vec![],
Expand Down Expand Up @@ -10585,6 +10681,7 @@ fn parse_connect_by() {
right: Box::new(Expr::Value(number("42"))),
}),
group_by: GroupByExpr::Expressions(vec![], vec![]),
grouping_sets: None,
cluster_by: vec![],
distribute_by: vec![],
sort_by: vec![],
Expand Down
Loading

0 comments on commit 9779f2e

Please sign in to comment.