diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionOptimization.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionOptimization.java index 78676ca3079c74..53f360d5ffe6c3 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionOptimization.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionOptimization.java @@ -22,6 +22,7 @@ import org.apache.doris.nereids.rules.expression.rules.OrToIn; import org.apache.doris.nereids.rules.expression.rules.SimplifyComparisonPredicate; import org.apache.doris.nereids.rules.expression.rules.SimplifyDecimalV3Comparison; +import org.apache.doris.nereids.rules.expression.rules.SimplifyInPredicate; import org.apache.doris.nereids.rules.expression.rules.SimplifyRange; import com.google.common.collect.ImmutableList; @@ -36,6 +37,7 @@ public class ExpressionOptimization extends ExpressionRewrite { ExtractCommonFactorRule.INSTANCE, DistinctPredicatesRule.INSTANCE, SimplifyComparisonPredicate.INSTANCE, + SimplifyInPredicate.INSTANCE, SimplifyDecimalV3Comparison.INSTANCE, SimplifyRange.INSTANCE, OrToIn.INSTANCE diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyInPredicate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyInPredicate.java new file mode 100644 index 00000000000000..8feb52cd4bcdbe --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyInPredicate.java @@ -0,0 +1,78 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.expression.rules; + +import org.apache.doris.nereids.rules.expression.AbstractExpressionRewriteRule; +import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext; +import org.apache.doris.nereids.trees.expressions.Cast; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.InPredicate; +import org.apache.doris.nereids.trees.expressions.literal.DateTimeV2Literal; +import org.apache.doris.nereids.trees.expressions.literal.DateV2Literal; + +import com.google.common.collect.Lists; + +import java.util.List; + +/** + * SimplifyInPredicate + */ +public class SimplifyInPredicate extends AbstractExpressionRewriteRule { + + public static final SimplifyInPredicate INSTANCE = new SimplifyInPredicate(); + + @Override + public Expression visitInPredicate(InPredicate expr, ExpressionRewriteContext context) { + if (expr.children().size() > 1) { + if (expr.getCompareExpr() instanceof Cast) { + Cast cast = (Cast) expr.getCompareExpr(); + if (cast.child().getDataType().isDateV2Type() + && expr.child(1) instanceof DateTimeV2Literal) { + List literals = expr.children().subList(1, expr.children().size()); + if (literals.stream().allMatch(literal -> literal instanceof DateTimeV2Literal + && canLosslessConvertToDateV2Literal((DateTimeV2Literal) literal))) { + List children = Lists.newArrayList(); + children.add(cast.child()); + literals.stream().forEach( + l -> children.add(convertToDateV2Literal((DateTimeV2Literal) l))); + return expr.withChildren(children); + } + } + } + } + return expr; + } + + /* + derive tree: + DateLiteral + | + +--->DateTimeLiteral + | | + | +----->DateTimeV2Literal + +--->DateV2Literal + */ + private static boolean canLosslessConvertToDateV2Literal(DateTimeV2Literal literal) { + return (literal.getHour() | literal.getMinute() | literal.getSecond() + | literal.getMicroSecond()) == 0L; + } + + private DateV2Literal convertToDateV2Literal(DateTimeV2Literal literal) { + return new DateV2Literal(literal.getYear(), literal.getMonth(), literal.getDay()); + } +} diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/ExpressionRewriteTestHelper.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/ExpressionRewriteTestHelper.java index 9cd4ea49c6463c..e8cc4dd266a9db 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/ExpressionRewriteTestHelper.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/ExpressionRewriteTestHelper.java @@ -29,6 +29,7 @@ import org.apache.doris.nereids.types.BigIntType; import org.apache.doris.nereids.types.BooleanType; import org.apache.doris.nereids.types.DataType; +import org.apache.doris.nereids.types.DateV2Type; import org.apache.doris.nereids.types.DoubleType; import org.apache.doris.nereids.types.IntegerType; import org.apache.doris.nereids.types.StringType; @@ -85,7 +86,7 @@ protected void assertRewriteAfterTypeCoercion(String expression, String expected Assertions.assertEquals(expectedExpression.toSql(), rewrittenExpression.toSql()); } - private Expression replaceUnboundSlot(Expression expression, Map mem) { + protected Expression replaceUnboundSlot(Expression expression, Map mem) { List children = Lists.newArrayList(); boolean hasNewChildren = false; for (Expression child : expression.children()) { @@ -103,7 +104,7 @@ private Expression replaceUnboundSlot(Expression expression, Map m return hasNewChildren ? expression.withChildren(children) : expression; } - private Expression typeCoercion(Expression expression) { + protected Expression typeCoercion(Expression expression) { return FunctionBinder.INSTANCE.rewrite(expression, null); } @@ -121,6 +122,8 @@ private DataType getType(char t) { return VarcharType.SYSTEM_DEFAULT; case 'B': return BooleanType.INSTANCE; + case 'C': + return DateV2Type.INSTANCE; default: return BigIntType.INSTANCE; } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/SimplifyInPredicateTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/SimplifyInPredicateTest.java new file mode 100644 index 00000000000000..01502bac522cfa --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/SimplifyInPredicateTest.java @@ -0,0 +1,53 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.expression; + +import org.apache.doris.nereids.rules.expression.rules.FoldConstantRule; +import org.apache.doris.nereids.rules.expression.rules.SimplifyInPredicate; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.Slot; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Maps; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +import java.util.Map; + +public class SimplifyInPredicateTest extends ExpressionRewriteTestHelper { + + @Test + public void test() { + executor = new ExpressionRuleExecutor(ImmutableList.of( + FoldConstantRule.INSTANCE, + SimplifyInPredicate.INSTANCE + )); + Map mem = Maps.newHashMap(); + Expression rewrittenExpression = PARSER.parseExpression("cast(CA as DATETIME) in ('1992-01-31 00:00:00', '1992-02-01 00:00:00')"); + rewrittenExpression = typeCoercion(replaceUnboundSlot(rewrittenExpression, mem)); + rewrittenExpression = executor.rewrite(rewrittenExpression, context); + Expression expectedExpression = PARSER.parseExpression("CA in (cast('1992-01-31' as date), cast('1992-02-01' as date))"); + expectedExpression = replaceUnboundSlot(expectedExpression, mem); + executor = new ExpressionRuleExecutor(ImmutableList.of( + FoldConstantRule.INSTANCE + )); + expectedExpression = executor.rewrite(expectedExpression, context); + Assertions.assertEquals(expectedExpression.toSql(), rewrittenExpression.toSql()); + } + +} diff --git a/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query83.out b/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query83.out index f1f3cfaf941eab..90df3709ecda1b 100644 --- a/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query83.out +++ b/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query83.out @@ -30,7 +30,7 @@ PhysicalResultSink --------------------------------PhysicalOlapScan[date_dim] ------------------------------PhysicalDistribute --------------------------------PhysicalProject -----------------------------------filter(cast(d_date as DATETIMEV2(0)) IN (2001-06-06 00:00:00, 2001-09-02 00:00:00, 2001-11-11 00:00:00)) +----------------------------------filter(d_date IN (2001-06-06, 2001-09-02, 2001-11-11)) ------------------------------------PhysicalOlapScan[date_dim] ------------hashJoin[INNER_JOIN](sr_items.item_id = wr_items.item_id) --------------PhysicalProject @@ -57,7 +57,7 @@ PhysicalResultSink ----------------------------------PhysicalOlapScan[date_dim] --------------------------------PhysicalDistribute ----------------------------------PhysicalProject -------------------------------------filter(cast(d_date as DATETIMEV2(0)) IN (2001-06-06 00:00:00, 2001-09-02 00:00:00, 2001-11-11 00:00:00)) +------------------------------------filter(d_date IN (2001-06-06, 2001-09-02, 2001-11-11)) --------------------------------------PhysicalOlapScan[date_dim] --------------PhysicalProject ----------------hashAgg[GLOBAL] @@ -83,6 +83,6 @@ PhysicalResultSink ----------------------------------PhysicalOlapScan[date_dim] --------------------------------PhysicalDistribute ----------------------------------PhysicalProject -------------------------------------filter(cast(d_date as DATETIMEV2(0)) IN (2001-06-06 00:00:00, 2001-09-02 00:00:00, 2001-11-11 00:00:00)) +------------------------------------filter(d_date IN (2001-06-06, 2001-09-02, 2001-11-11)) --------------------------------------PhysicalOlapScan[date_dim] diff --git a/regression-test/suites/nereids_syntax_p0/test_simplify_in_predicate.groovy b/regression-test/suites/nereids_syntax_p0/test_simplify_in_predicate.groovy new file mode 100644 index 00000000000000..0079d5a2bdeaa3 --- /dev/null +++ b/regression-test/suites/nereids_syntax_p0/test_simplify_in_predicate.groovy @@ -0,0 +1,38 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +suite("test_simplify_in_predicate") { + sql "set enable_nereids_planner=true" + sql 'set enable_fallback_to_original_planner=false;' + sql 'drop table if exists test_simplify_in_predicate_t' + sql """CREATE TABLE IF NOT EXISTS `test_simplify_in_predicate_t` ( + a DATE NOT NULL + ) ENGINE=OLAP + UNIQUE KEY (`a`) + DISTRIBUTED BY HASH(`a`) BUCKETS 120 + PROPERTIES ( + "replication_num" = "1", + "in_memory" = "false", + "compression" = "LZ4" + );""" + sql """insert into test_simplify_in_predicate_t values( "2023-06-06" );""" + + explain { + sql "verbose select * from test_simplify_in_predicate_t where a in ('1992-01-31', '1992-02-01', '1992-02-02', '1992-02-03', '1992-02-04');" + notContains "CAST" + } +} \ No newline at end of file