Skip to content

Commit

Permalink
[CALCITE-4345] AggregateCaseToFilterRule throws NullPointerException when converting CASE without ELSE (Jiatao Tao)
Browse files Browse the repository at this point in the history

For example, 'SUM(CASE WHEN b THEN 1 END)' is equivalent to
'SUM(CASE WHEN b THEN 1 ELSE NULL END)', and both should be
converted to 'SUM(1) FILTER (WHERE b)', but before this bug
was fixed the former would throw NullPointerException.

Close apache#2225
  • Loading branch information
Aaaaaaron authored and XuQianJin-Stars committed Jul 14, 2021
1 parent e154d75 commit ce75aa1
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@

import com.google.common.collect.ImmutableList;

import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.List;
import javax.annotation.Nullable;
Expand Down Expand Up @@ -227,8 +228,8 @@ && isThreeArgCase(project.getProjects().get(singleArg))) {
RelCollations.EMPTY, aggregateCall.getType(),
aggregateCall.getName());
} else if (kind == SqlKind.SUM // Case B
&& isIntLiteral(arg1) && RexLiteral.intValue(arg1) == 1
&& isIntLiteral(arg2) && RexLiteral.intValue(arg2) == 0) {
&& isIntLiteral(arg1, BigDecimal.ONE)
&& isIntLiteral(arg2, BigDecimal.ZERO)) {

newProjects.add(filter);
final RelDataTypeFactory typeFactory = cluster.getTypeFactory();
Expand All @@ -241,8 +242,7 @@ && isIntLiteral(arg2) && RexLiteral.intValue(arg2) == 0) {
} else if ((RexLiteral.isNullLiteral(arg2) // Case A1
&& aggregateCall.getAggregation().allowsFilter())
|| (kind == SqlKind.SUM // Case A2
&& isIntLiteral(arg2)
&& RexLiteral.intValue(arg2) == 0)) {
&& isIntLiteral(arg2, BigDecimal.ZERO))) {
newProjects.add(arg1);
newProjects.add(filter);
return AggregateCall.create(aggregateCall.getAggregation(), false,
Expand All @@ -267,9 +267,10 @@ private static boolean isThreeArgCase(final RexNode rexNode) {
&& ((RexCall) rexNode).operands.size() == 3;
}

/**
 * Returns whether {@code rexNode} is a {@code RexLiteral} whose SQL type name
 * is contained in {@code SqlTypeName.INT_TYPES} and, in the two-argument form,
 * whose value read as a {@code BigDecimal} equals {@code value}.
 *
 * <p>NOTE(review): this span is a diff hunk rendered without +/- markers. The
 * one-argument signature and the return expression that ends right after the
 * {@code INT_TYPES} test appear to be the removed lines, and the two-argument
 * signature with the extra {@code value.equals(...)} test the added lines;
 * confirm against the repository before treating either as current.
 *
 * @param rexNode expression to test
 * @param value   expected literal value (two-argument form only)
 * @return {@code true} only if every test in the return expression holds
 */
private static boolean isIntLiteral(final RexNode rexNode) {
private static boolean isIntLiteral(RexNode rexNode, BigDecimal value) {
return rexNode instanceof RexLiteral
&& SqlTypeName.INT_TYPES.contains(rexNode.getType().getSqlTypeName());
&& SqlTypeName.INT_TYPES.contains(rexNode.getType().getSqlTypeName())
// BigDecimal.equals returns false for a null argument instead of throwing,
// so a literal whose getValueAs result is null just fails this test.
// BigDecimal.equals also compares scale; presumably integer literals carry
// scale 0 like BigDecimal.ONE and BigDecimal.ZERO - TODO confirm.
&& value.equals(((RexLiteral) rexNode).getValueAs(BigDecimal.class));
}

/** Rule configuration. */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3826,6 +3826,10 @@ public boolean test(Project project) {
+ " sum(case when deptno = 20 then sal else 0 end) as sum_sal_d20,\n"
+ " sum(case when deptno = 30 then 1 else 0 end) as count_d30,\n"
+ " count(case when deptno = 40 then 'x' end) as count_d40,\n"
+ " sum(case when deptno = 45 then 1 end) as count_d45,\n"
+ " sum(case when deptno = 50 then 1 else null end) as count_d50,\n"
+ " sum(case when deptno = 60 then null end) as sum_null_d60,\n"
+ " sum(case when deptno = 70 then null else 1 end) as sum_null_d70,\n"
+ " count(case when deptno = 20 then 1 end) as count_d20\n"
+ "from emp";
sql(sql).withRule(CoreRules.AGGREGATE_CASE_TO_FILTER).check();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,21 +27,25 @@
sum(case when deptno = 20 then sal else 0 end) as sum_sal_d20,
sum(case when deptno = 30 then 1 else 0 end) as count_d30,
count(case when deptno = 40 then 'x' end) as count_d40,
sum(case when deptno = 45 then 1 end) as count_d45,
sum(case when deptno = 50 then 1 else null end) as count_d50,
sum(case when deptno = 60 then null end) as sum_null_d60,
sum(case when deptno = 70 then null else 1 end) as sum_null_d70,
count(case when deptno = 20 then 1 end) as count_d20
from emp]]>
</Resource>
<Resource name="planBefore">
<![CDATA[
LogicalAggregate(group=[{}], SUM_SAL=[SUM($0)], COUNT_DISTINCT_CLERK=[COUNT(DISTINCT $1)], SUM_SAL_D10=[SUM($2)], SUM_SAL_D20=[SUM($3)], COUNT_D30=[SUM($4)], COUNT_D40=[COUNT($5)], COUNT_D20=[COUNT($6)])
LogicalProject(SAL=[$5], $f1=[CASE(=($2, 'CLERK'), $7, null:INTEGER)], $f2=[CASE(=($7, 10), $5, null:INTEGER)], $f3=[CASE(=($7, 20), $5, 0)], $f4=[CASE(=($7, 30), 1, 0)], $f5=[CASE(=($7, 40), 'x', null:CHAR(1))], $f6=[CASE(=($7, 20), 1, null:INTEGER)])
LogicalAggregate(group=[{}], SUM_SAL=[SUM($0)], COUNT_DISTINCT_CLERK=[COUNT(DISTINCT $1)], SUM_SAL_D10=[SUM($2)], SUM_SAL_D20=[SUM($3)], COUNT_D30=[SUM($4)], COUNT_D40=[COUNT($5)], COUNT_D45=[SUM($6)], COUNT_D50=[SUM($7)], SUM_NULL_D60=[SUM($8)], SUM_NULL_D70=[SUM($9)], COUNT_D20=[COUNT($10)])
LogicalProject(SAL=[$5], $f1=[CASE(=($2, 'CLERK'), $7, null:INTEGER)], $f2=[CASE(=($7, 10), $5, null:INTEGER)], $f3=[CASE(=($7, 20), $5, 0)], $f4=[CASE(=($7, 30), 1, 0)], $f5=[CASE(=($7, 40), 'x', null:CHAR(1))], $f6=[CASE(=($7, 45), 1, null:INTEGER)], $f7=[CASE(=($7, 50), 1, null:INTEGER)], $f8=[null:DECIMAL(19, 9)], $f9=[CASE(=($7, 70), null:INTEGER, 1)], $f10=[CASE(=($7, 20), 1, null:INTEGER)])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
]]>
</Resource>
<Resource name="planAfter">
<![CDATA[
LogicalProject(SUM_SAL=[$0], COUNT_DISTINCT_CLERK=[$1], SUM_SAL_D10=[$2], SUM_SAL_D20=[$3], COUNT_D30=[CAST($4):INTEGER], COUNT_D40=[$5], COUNT_D20=[$6])
LogicalAggregate(group=[{}], SUM_SAL=[SUM($0)], COUNT_DISTINCT_CLERK=[COUNT(DISTINCT $1) FILTER $2], SUM_SAL_D10=[SUM($3) FILTER $4], SUM_SAL_D20=[SUM($5) FILTER $6], COUNT_D30=[COUNT() FILTER $7], COUNT_D40=[COUNT() FILTER $8], COUNT_D20=[COUNT() FILTER $9])
LogicalProject(SAL=[$5], DEPTNO=[$7], $f8=[=($2, 'CLERK')], SAL0=[$5], $f10=[=($7, 10)], SAL1=[$5], $f12=[=($7, 20)], $f13=[=($7, 30)], $f14=[=($7, 40)], $f15=[=($7, 20)])
LogicalProject(SUM_SAL=[$0], COUNT_DISTINCT_CLERK=[$1], SUM_SAL_D10=[$2], SUM_SAL_D20=[$3], COUNT_D30=[CAST($4):INTEGER], COUNT_D40=[$5], COUNT_D45=[$6], COUNT_D50=[$7], SUM_NULL_D60=[$8], SUM_NULL_D70=[$9], COUNT_D20=[$10])
LogicalAggregate(group=[{}], SUM_SAL=[SUM($0)], COUNT_DISTINCT_CLERK=[COUNT(DISTINCT $2) FILTER $3], SUM_SAL_D10=[SUM($4) FILTER $5], SUM_SAL_D20=[SUM($6) FILTER $7], COUNT_D30=[COUNT() FILTER $8], COUNT_D40=[COUNT() FILTER $9], COUNT_D45=[SUM($10) FILTER $11], COUNT_D50=[SUM($12) FILTER $13], SUM_NULL_D60=[SUM($1)], SUM_NULL_D70=[SUM($14) FILTER $15], COUNT_D20=[COUNT() FILTER $16])
LogicalProject(SAL=[$5], $f8=[null:DECIMAL(19, 9)], DEPTNO=[$7], $f12=[=($2, 'CLERK')], SAL0=[$5], $f14=[=($7, 10)], SAL1=[$5], $f16=[=($7, 20)], $f17=[=($7, 30)], $f18=[=($7, 40)], $f19=[1], $f20=[=($7, 45)], $f21=[1], $f22=[=($7, 50)], $f23=[1], $f24=[<>($7, 70)], $f25=[=($7, 20)])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
]]>
</Resource>
Expand Down
51 changes: 51 additions & 0 deletions core/src/test/resources/sql/agg.iq
Original file line number Diff line number Diff line change
Expand Up @@ -2245,6 +2245,57 @@ EnumerableCalc(expr#0=[{inputs}], expr#1=[0:BIGINT], expr#2=[=($t0, $t1)], expr#

!use scott

# [CALCITE-4345] SUM(CASE WHEN b THEN 1 END) and similar CASE expressions, with and without ELSE
select
sum(sal) as sum_sal,
count(distinct case
when job = 'CLERK'
then deptno else null end) as count_distinct_clerk,
sum(case when deptno = 10 then sal end) as sum_sal_d10,
sum(case when deptno = 20 then sal else 0 end) as sum_sal_d20,
sum(case when deptno = 30 then 1 else 0 end) as count_d30,
count(case when deptno = 40 then 'x' end) as count_d40,
sum(case when deptno = 45 then 1 end) as count_d45,
sum(case when deptno = 50 then 1 else null end) as count_d50,
sum(case when deptno = 60 then null end) as sum_null_d60,
sum(case when deptno = 70 then null else 1 end) as sum_null_d70,
count(case when deptno = 20 then 1 end) as count_d20
from emp;
+----------+----------------------+-------------+-------------+-----------+-----------+-----------+-----------+--------------+--------------+-----------+
| SUM_SAL | COUNT_DISTINCT_CLERK | SUM_SAL_D10 | SUM_SAL_D20 | COUNT_D30 | COUNT_D40 | COUNT_D45 | COUNT_D50 | SUM_NULL_D60 | SUM_NULL_D70 | COUNT_D20 |
+----------+----------------------+-------------+-------------+-----------+-----------+-----------+-----------+--------------+--------------+-----------+
| 29025.00 | 3 | 8750.00 | 10875.00 | 6 | 0 | | | | 14 | 5 |
+----------+----------------------+-------------+-------------+-----------+-----------+-----------+-----------+--------------+--------------+-----------+
(1 row)

!ok

# Check that SUM produces NULL on empty set, COUNT produces 0.
select
sum(sal) as sum_sal,
count(distinct case
when job = 'CLERK'
then deptno else null end) as count_distinct_clerk,
sum(case when deptno = 10 then sal end) as sum_sal_d10,
sum(case when deptno = 20 then sal else 0 end) as sum_sal_d20,
sum(case when deptno = 30 then 1 else 0 end) as count_d30,
count(case when deptno = 40 then 'x' end) as count_d40,
sum(case when deptno = 45 then 1 end) as count_d45,
sum(case when deptno = 50 then 1 else null end) as count_d50,
sum(case when deptno = 60 then null end) as sum_null_d60,
sum(case when deptno = 70 then null else 1 end) as sum_null_d70,
count(case when deptno = 20 then 1 end) as count_d20
from emp
where false;
+---------+----------------------+-------------+-------------+-----------+-----------+-----------+-----------+--------------+--------------+-----------+
| SUM_SAL | COUNT_DISTINCT_CLERK | SUM_SAL_D10 | SUM_SAL_D20 | COUNT_D30 | COUNT_D40 | COUNT_D45 | COUNT_D50 | SUM_NULL_D60 | SUM_NULL_D70 | COUNT_D20 |
+---------+----------------------+-------------+-------------+-----------+-----------+-----------+-----------+--------------+--------------+-----------+
| | 0 | | | | 0 | | | | | 0 |
+---------+----------------------+-------------+-------------+-----------+-----------+-----------+-----------+--------------+--------------+-----------+
(1 row)

!ok

# [CALCITE-1930] AggregateExpandDistinctAggregateRules should handle multiple aggregate calls with same input ref
select count(distinct EMPNO), COUNT(SAL), MIN(SAL), MAX(SAL) from "scott".emp;
+--------+--------+--------+---------+
Expand Down

0 comments on commit ce75aa1

Please sign in to comment.