Skip to content

Commit

Permalink
reserve literal exprs (#4031)
Browse files Browse the repository at this point in the history
Signed-off-by: remzi <[email protected]>

Signed-off-by: remzi <[email protected]>
  • Loading branch information
HaoYang670 authored Nov 2, 2022
1 parent d096215 commit 065a478
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 23 deletions.
6 changes: 3 additions & 3 deletions datafusion/core/tests/sql/aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -538,14 +538,14 @@ async fn csv_query_count_star() {
}

#[tokio::test]
async fn csv_query_count_one() {
async fn csv_query_count_literal() {
let ctx = SessionContext::new();
register_aggregate_csv_by_sql(&ctx).await;
let sql = "SELECT COUNT(1) FROM aggregate_test_100";
let sql = "SELECT COUNT(2) FROM aggregate_test_100";
let actual = execute_to_batches(&ctx, sql).await;
let expected = vec![
"+-----------------+",
"| COUNT(UInt8(1)) |",
"| COUNT(Int64(2)) |",
"+-----------------+",
"| 100 |",
"+-----------------+",
Expand Down
6 changes: 3 additions & 3 deletions datafusion/core/tests/sql/predicates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ async fn csv_query_with_negated_predicate() -> Result<()> {
let actual = execute_to_batches(&ctx, sql).await;
let expected = vec![
"+-----------------+",
"| COUNT(UInt8(1)) |",
"| COUNT(Int64(1)) |",
"+-----------------+",
"| 21 |",
"+-----------------+",
Expand All @@ -78,7 +78,7 @@ async fn csv_query_with_is_not_null_predicate() -> Result<()> {
let actual = execute_to_batches(&ctx, sql).await;
let expected = vec![
"+-----------------+",
"| COUNT(UInt8(1)) |",
"| COUNT(Int64(1)) |",
"+-----------------+",
"| 100 |",
"+-----------------+",
Expand All @@ -95,7 +95,7 @@ async fn csv_query_with_is_null_predicate() -> Result<()> {
let actual = execute_to_batches(&ctx, sql).await;
let expected = vec![
"+-----------------+",
"| COUNT(UInt8(1)) |",
"| COUNT(Int64(1)) |",
"+-----------------+",
"| 0 |",
"+-----------------+",
Expand Down
4 changes: 2 additions & 2 deletions datafusion/optimizer/tests/integration-test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ fn between_date32_plus_interval() -> Result<()> {
WHERE col_date32 between '1998-03-18' AND cast('1998-03-18' as date) + INTERVAL '90 days'";
let plan = test_sql(sql)?;
let expected =
"Projection: COUNT(UInt8(1))\n Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1))]]\
"Projection: COUNT(Int64(1))\n Aggregate: groupBy=[[]], aggr=[[COUNT(Int64(1))]]\
\n Filter: test.col_date32 >= Date32(\"10303\") AND test.col_date32 <= Date32(\"10393\")\
\n TableScan: test projection=[col_date32]";
assert_eq!(expected, format!("{:?}", plan));
Expand All @@ -193,7 +193,7 @@ fn between_date64_plus_interval() -> Result<()> {
WHERE col_date64 between '1998-03-18T00:00:00' AND cast('1998-03-18' as date) + INTERVAL '90 days'";
let plan = test_sql(sql)?;
let expected =
"Projection: COUNT(UInt8(1))\n Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1))]]\
"Projection: COUNT(Int64(1))\n Aggregate: groupBy=[[]], aggr=[[COUNT(Int64(1))]]\
\n Filter: test.col_date64 >= Date64(\"890179200000\") AND test.col_date64 <= Date64(\"897955200000\")\
\n TableScan: test projection=[col_date64]";
assert_eq!(expected, format!("{:?}", plan));
Expand Down
26 changes: 11 additions & 15 deletions datafusion/sql/src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2187,7 +2187,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
) => {
let (aggregate_fun, args) = self.aggregate_fn_to_expr(
aggregate_fun,
function,
function.args,
schema,
)?;

Expand Down Expand Up @@ -2220,7 +2220,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
// next, aggregate built-ins
if let Ok(fun) = AggregateFunction::from_str(&name) {
let distinct = function.distinct;
let (fun, args) = self.aggregate_fn_to_expr(fun, function, schema)?;
let (fun, args) = self.aggregate_fn_to_expr(fun, function.args, schema)?;
return Ok(Expr::AggregateFunction {
fun,
distinct,
Expand Down Expand Up @@ -2344,25 +2344,21 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
fn aggregate_fn_to_expr(
&self,
fun: AggregateFunction,
function: sqlparser::ast::Function,
args: Vec<FunctionArg>,
schema: &DFSchema,
) -> Result<(AggregateFunction, Vec<Expr>)> {
let args = match fun {
// Special case rewrite COUNT(*) to COUNT(constant)
AggregateFunction::Count => function
.args
AggregateFunction::Count => args
.into_iter()
.map(|a| match a {
FunctionArg::Unnamed(FunctionArgExpr::Expr(SQLExpr::Value(
Value::Number(_, _),
))) => Ok(Expr::Literal(COUNT_STAR_EXPANSION.clone())),
FunctionArg::Unnamed(FunctionArgExpr::Wildcard) => {
Ok(Expr::Literal(COUNT_STAR_EXPANSION.clone()))
}
_ => self.sql_fn_arg_to_logical_expr(a, schema, &mut HashMap::new()),
})
.collect::<Result<Vec<Expr>>>()?,
_ => self.function_args_to_expr(function.args, schema)?,
_ => self.function_args_to_expr(args, schema)?,
};

Ok((fun, args))
Expand Down Expand Up @@ -3662,14 +3658,14 @@ mod tests {
fn select_simple_aggregate_with_groupby_can_use_positions() {
quick_test(
"SELECT state, age AS b, COUNT(1) FROM person GROUP BY 1, 2",
"Projection: person.state, person.age AS b, COUNT(UInt8(1))\
\n Aggregate: groupBy=[[person.state, person.age]], aggr=[[COUNT(UInt8(1))]]\
"Projection: person.state, person.age AS b, COUNT(Int64(1))\
\n Aggregate: groupBy=[[person.state, person.age]], aggr=[[COUNT(Int64(1))]]\
\n TableScan: person",
);
quick_test(
"SELECT state, age AS b, COUNT(1) FROM person GROUP BY 2, 1",
"Projection: person.state, person.age AS b, COUNT(UInt8(1))\
\n Aggregate: groupBy=[[person.age, person.state]], aggr=[[COUNT(UInt8(1))]]\
"Projection: person.state, person.age AS b, COUNT(Int64(1))\
\n Aggregate: groupBy=[[person.age, person.state]], aggr=[[COUNT(Int64(1))]]\
\n TableScan: person",
);
}
Expand Down Expand Up @@ -3834,8 +3830,8 @@ mod tests {
#[test]
fn select_count_one() {
let sql = "SELECT COUNT(1) FROM person";
let expected = "Projection: COUNT(UInt8(1))\
\n Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1))]]\
let expected = "Projection: COUNT(Int64(1))\
\n Aggregate: groupBy=[[]], aggr=[[COUNT(Int64(1))]]\
\n TableScan: person";
quick_test(sql, expected);
}
Expand Down

0 comments on commit 065a478

Please sign in to comment.