Skip to content

Commit

Permalink
[fix](nereids) remove useless cast in in-predicate (apache#23171)
Browse files Browse the repository at this point in the history
consider sql "select * from test_simplify_in_predicate_t where a in ('1992-01-31', '1992-02-01', '1992-02-02', '1992-02-03', '1992-02-04');"
before:

```
|   0:VOlapScanNode                                                                                                                                                                                      |
|      TABLE: default_cluster:bugfix.test_simplify_in_predicate_t(test_simplify_in_predicate_t), PREAGGREGATION: OFF. Reason: No aggregate on scan.                                                      |
|      PREDICATES: CAST(a[#0] AS DATETIMEV2(0)) IN ('1992-01-31 00:00:00', '1992-02-01 00:00:00', '1992-02-02 00:00:00', '1992-02-03 00:00:00', '1992-02-04 00:00:00') AND __DORIS_DELETE_SIGN__[#1] = 0 |
|      partitions=0/1, tablets=0/0, tabletList=                                                                                                                                                          |
|      cardinality=1, avgRowSize=0.0, numNodes=1                                                                                                                                                         |
|      pushAggOp=NONE                                                                                                                                                                                    |
|      projections: a[#0]                                                                                                                                                                                |
|      project output tuple id: 1                                                                                                                                                                        |
|      tuple ids: 0  
```
after:

```
|   0:VOlapScanNode                                                                                                                                 |
|      TABLE: default_cluster:bugfix.test_simplify_in_predicate_t(test_simplify_in_predicate_t), PREAGGREGATION: OFF. Reason: No aggregate on scan. |
|      PREDICATES: a[#0] IN ('1992-01-31', '1992-02-01', '1992-02-02', '1992-02-03', '1992-02-04') AND __DORIS_DELETE_SIGN__[#1] = 0                |
|      partitions=0/1, tablets=0/0, tabletList=                                                                                                     |
|      cardinality=1, avgRowSize=0.0, numNodes=1                                                                                                    |
|      pushAggOp=NONE                                                                                                                               |
|      projections: a[#0]                                                                                                                           |
|      project output tuple id: 1                                                                                                                   |
|      tuple ids: 0  

```
  • Loading branch information
starocean999 authored and xiaokang committed Aug 24, 2023
1 parent 1acdf55 commit c749bc9
Show file tree
Hide file tree
Showing 6 changed files with 179 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.apache.doris.nereids.rules.expression.rules.OrToIn;
import org.apache.doris.nereids.rules.expression.rules.SimplifyComparisonPredicate;
import org.apache.doris.nereids.rules.expression.rules.SimplifyDecimalV3Comparison;
import org.apache.doris.nereids.rules.expression.rules.SimplifyInPredicate;
import org.apache.doris.nereids.rules.expression.rules.SimplifyRange;

import com.google.common.collect.ImmutableList;
Expand All @@ -36,6 +37,7 @@ public class ExpressionOptimization extends ExpressionRewrite {
ExtractCommonFactorRule.INSTANCE,
DistinctPredicatesRule.INSTANCE,
SimplifyComparisonPredicate.INSTANCE,
SimplifyInPredicate.INSTANCE,
SimplifyDecimalV3Comparison.INSTANCE,
SimplifyRange.INSTANCE,
OrToIn.INSTANCE
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.rules.expression.rules;

import org.apache.doris.nereids.rules.expression.AbstractExpressionRewriteRule;
import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.InPredicate;
import org.apache.doris.nereids.trees.expressions.literal.DateTimeV2Literal;
import org.apache.doris.nereids.trees.expressions.literal.DateV2Literal;

import com.google.common.collect.Lists;

import java.util.List;

/**
* SimplifyInPredicate
*/
public class SimplifyInPredicate extends AbstractExpressionRewriteRule {

public static final SimplifyInPredicate INSTANCE = new SimplifyInPredicate();

@Override
public Expression visitInPredicate(InPredicate expr, ExpressionRewriteContext context) {
if (expr.children().size() > 1) {
if (expr.getCompareExpr() instanceof Cast) {
Cast cast = (Cast) expr.getCompareExpr();
if (cast.child().getDataType().isDateV2Type()
&& expr.child(1) instanceof DateTimeV2Literal) {
List<Expression> literals = expr.children().subList(1, expr.children().size());
if (literals.stream().allMatch(literal -> literal instanceof DateTimeV2Literal
&& canLosslessConvertToDateV2Literal((DateTimeV2Literal) literal))) {
List<Expression> children = Lists.newArrayList();
children.add(cast.child());
literals.stream().forEach(
l -> children.add(convertToDateV2Literal((DateTimeV2Literal) l)));
return expr.withChildren(children);
}
}
}
}
return expr;
}

/*
derive tree:
DateLiteral
|
+--->DateTimeLiteral
| |
| +----->DateTimeV2Literal
+--->DateV2Literal
*/
private static boolean canLosslessConvertToDateV2Literal(DateTimeV2Literal literal) {
return (literal.getHour() | literal.getMinute() | literal.getSecond()
| literal.getMicroSecond()) == 0L;
}

private DateV2Literal convertToDateV2Literal(DateTimeV2Literal literal) {
return new DateV2Literal(literal.getYear(), literal.getMonth(), literal.getDay());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import org.apache.doris.nereids.types.BigIntType;
import org.apache.doris.nereids.types.BooleanType;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.DateV2Type;
import org.apache.doris.nereids.types.DoubleType;
import org.apache.doris.nereids.types.IntegerType;
import org.apache.doris.nereids.types.StringType;
Expand Down Expand Up @@ -85,7 +86,7 @@ protected void assertRewriteAfterTypeCoercion(String expression, String expected
Assertions.assertEquals(expectedExpression.toSql(), rewrittenExpression.toSql());
}

private Expression replaceUnboundSlot(Expression expression, Map<String, Slot> mem) {
protected Expression replaceUnboundSlot(Expression expression, Map<String, Slot> mem) {
List<Expression> children = Lists.newArrayList();
boolean hasNewChildren = false;
for (Expression child : expression.children()) {
Expand All @@ -103,7 +104,7 @@ private Expression replaceUnboundSlot(Expression expression, Map<String, Slot> m
return hasNewChildren ? expression.withChildren(children) : expression;
}

private Expression typeCoercion(Expression expression) {
protected Expression typeCoercion(Expression expression) {
return FunctionBinder.INSTANCE.rewrite(expression, null);
}

Expand All @@ -121,6 +122,8 @@ private DataType getType(char t) {
return VarcharType.SYSTEM_DEFAULT;
case 'B':
return BooleanType.INSTANCE;
case 'C':
return DateV2Type.INSTANCE;
default:
return BigIntType.INSTANCE;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.rules.expression;

import org.apache.doris.nereids.rules.expression.rules.FoldConstantRule;
import org.apache.doris.nereids.rules.expression.rules.SimplifyInPredicate;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Slot;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.Maps;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;

import java.util.Map;

public class SimplifyInPredicateTest extends ExpressionRewriteTestHelper {

@Test
public void test() {
executor = new ExpressionRuleExecutor(ImmutableList.of(
FoldConstantRule.INSTANCE,
SimplifyInPredicate.INSTANCE
));
Map<String, Slot> mem = Maps.newHashMap();
Expression rewrittenExpression = PARSER.parseExpression("cast(CA as DATETIME) in ('1992-01-31 00:00:00', '1992-02-01 00:00:00')");
rewrittenExpression = typeCoercion(replaceUnboundSlot(rewrittenExpression, mem));
rewrittenExpression = executor.rewrite(rewrittenExpression, context);
Expression expectedExpression = PARSER.parseExpression("CA in (cast('1992-01-31' as date), cast('1992-02-01' as date))");
expectedExpression = replaceUnboundSlot(expectedExpression, mem);
executor = new ExpressionRuleExecutor(ImmutableList.of(
FoldConstantRule.INSTANCE
));
expectedExpression = executor.rewrite(expectedExpression, context);
Assertions.assertEquals(expectedExpression.toSql(), rewrittenExpression.toSql());
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ PhysicalResultSink
--------------------------------PhysicalOlapScan[date_dim]
------------------------------PhysicalDistribute
--------------------------------PhysicalProject
----------------------------------filter(cast(d_date as DATETIMEV2(0)) IN (2001-06-06 00:00:00, 2001-09-02 00:00:00, 2001-11-11 00:00:00))
----------------------------------filter(d_date IN (2001-06-06, 2001-09-02, 2001-11-11))
------------------------------------PhysicalOlapScan[date_dim]
------------hashJoin[INNER_JOIN](sr_items.item_id = wr_items.item_id)
--------------PhysicalProject
Expand All @@ -57,7 +57,7 @@ PhysicalResultSink
----------------------------------PhysicalOlapScan[date_dim]
--------------------------------PhysicalDistribute
----------------------------------PhysicalProject
------------------------------------filter(cast(d_date as DATETIMEV2(0)) IN (2001-06-06 00:00:00, 2001-09-02 00:00:00, 2001-11-11 00:00:00))
------------------------------------filter(d_date IN (2001-06-06, 2001-09-02, 2001-11-11))
--------------------------------------PhysicalOlapScan[date_dim]
--------------PhysicalProject
----------------hashAgg[GLOBAL]
Expand All @@ -83,6 +83,6 @@ PhysicalResultSink
----------------------------------PhysicalOlapScan[date_dim]
--------------------------------PhysicalDistribute
----------------------------------PhysicalProject
------------------------------------filter(cast(d_date as DATETIMEV2(0)) IN (2001-06-06 00:00:00, 2001-09-02 00:00:00, 2001-11-11 00:00:00))
------------------------------------filter(d_date IN (2001-06-06, 2001-09-02, 2001-11-11))
--------------------------------------PhysicalOlapScan[date_dim]

Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

suite("test_simplify_in_predicate") {
sql "set enable_nereids_planner=true"
sql 'set enable_fallback_to_original_planner=false;'
sql 'drop table if exists test_simplify_in_predicate_t'
sql """CREATE TABLE IF NOT EXISTS `test_simplify_in_predicate_t` (
a DATE NOT NULL
) ENGINE=OLAP
UNIQUE KEY (`a`)
DISTRIBUTED BY HASH(`a`) BUCKETS 120
PROPERTIES (
"replication_num" = "1",
"in_memory" = "false",
"compression" = "LZ4"
);"""
sql """insert into test_simplify_in_predicate_t values( "2023-06-06" );"""

explain {
sql "verbose select * from test_simplify_in_predicate_t where a in ('1992-01-31', '1992-02-01', '1992-02-02', '1992-02-03', '1992-02-04');"
notContains "CAST"
}
}

0 comments on commit c749bc9

Please sign in to comment.