Skip to content

Commit

Permalink
fix poc
Browse files Browse the repository at this point in the history
  • Loading branch information
seawinde committed Dec 5, 2024
1 parent f04e6be commit 2a3a9fa
Show file tree
Hide file tree
Showing 6 changed files with 82 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
import org.apache.doris.nereids.util.ExpressionUtils;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Sets;

Expand Down Expand Up @@ -127,7 +128,7 @@ protected Plan rewriteQueryByView(MatchMode matchMode,
materializationContext.getShuttledExprToScanExprMapping(),
viewToQuerySlotMapping,
queryStructInfo.getTableBitSet(),
cascadesContext);
ImmutableMap.of(), cascadesContext);
boolean isRewrittenQueryExpressionValid = true;
if (!rewrittenQueryExpressions.isEmpty()) {
List<NamedExpression> projects = new ArrayList<>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;

import com.google.common.collect.ImmutableMap;

import java.util.List;
import java.util.stream.Collectors;

Expand All @@ -50,7 +52,7 @@ protected Plan rewriteQueryByView(MatchMode matchMode,
materializationContext.getShuttledExprToScanExprMapping(),
targetToSourceMapping,
queryStructInfo.getTableBitSet(),
cascadesContext
ImmutableMap.of(), cascadesContext
);
// Can not rewrite, bail out
if (expressionsRewritten.isEmpty()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@
import org.apache.doris.nereids.rules.exploration.mv.mapping.ExpressionMapping;
import org.apache.doris.nereids.rules.exploration.mv.mapping.RelationMapping;
import org.apache.doris.nereids.rules.exploration.mv.mapping.SlotMapping;
import org.apache.doris.nereids.rules.expression.ExpressionOptimization;
import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
import org.apache.doris.nereids.rules.expression.rules.FoldConstantRuleOnFE;
import org.apache.doris.nereids.rules.rewrite.MergeProjects;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Expression;
Expand Down Expand Up @@ -245,7 +245,9 @@ protected List<Plan> doRewrite(StructInfo queryStructInfo, CascadesContext casca
// Try to rewrite compensate predicates by using mv scan
List<Expression> rewriteCompensatePredicates = rewriteExpression(compensatePredicates.toList(),
queryPlan, materializationContext.getShuttledExprToScanExprMapping(),
viewToQuerySlotMapping, queryStructInfo.getTableBitSet(), cascadesContext);
viewToQuerySlotMapping, queryStructInfo.getTableBitSet(),
compensatePredicates.getRangePredicateMap(),
cascadesContext);
if (rewriteCompensatePredicates.isEmpty()) {
materializationContext.recordFailReason(queryStructInfo,
"Rewrite compensate predicate by view fail",
Expand Down Expand Up @@ -567,7 +569,7 @@ protected Plan rewriteQueryByView(MatchMode matchMode, StructInfo queryStructInf
*/
protected List<Expression> rewriteExpression(List<? extends Expression> sourceExpressionsToWrite, Plan sourcePlan,
ExpressionMapping targetExpressionMapping, SlotMapping targetToSourceMapping, BitSet sourcePlanBitSet,
CascadesContext cascadesContext) {
Map<Expression, Literal> shuttledQueryMap, CascadesContext cascadesContext) {
// Firstly, rewrite the target expression using source with inverse mapping
// then try to use the target expression to represent the query. if any of source expressions
// can not be represented by target expressions, return null.
Expand All @@ -586,7 +588,7 @@ protected List<Expression> rewriteExpression(List<? extends Expression> sourceEx
rewrittenExpressions.add(expressionShuttledToRewrite);
continue;
}
final Set<Slot> slotsToRewrite =
final Set<Expression> slotsToRewrite =
expressionShuttledToRewrite.collectToSet(expression -> expression instanceof Slot);

final Set<SlotReference> variants =
Expand All @@ -595,34 +597,49 @@ protected List<Expression> rewriteExpression(List<? extends Expression> sourceEx
extendMappingByVariant(variants, targetToTargetReplacementMappingQueryBased);
Expression replacedExpression = ExpressionUtils.replace(expressionShuttledToRewrite,
targetToTargetReplacementMappingQueryBased);
Set<Slot> replacedExpressionSlotQueryUsed = replacedExpression.collect(slotsToRewrite::contains);
Set<Expression> replacedExpressionSlotQueryUsed = replacedExpression.collect(slotsToRewrite::contains);
if (!replacedExpressionSlotQueryUsed.isEmpty()) {
// if contains any slot to rewrite, which means can not be rewritten by target,
// shuttled query expr is slot#0 > '2024-01-01' but mv plan output is date_trunc(slot#0, 'day')
// expressionShuttledToRewrite is slot#0 > '2024-01-01' but mv plan output is date_trunc(slot#0, 'day')
// which would try to rewrite
// slotDateTruncMap is {date_trunc(slot#0, 'day') : mv_scan_date_trunc_slot#10}
Map<Slot, DateTrunc> slotDateTruncMap = new HashMap<>();
// paramExpressionToDateTruncMap is {slot#0 : date_trunc(slot#0, 'day')}
Map<Expression, DateTrunc> paramExpressionToDateTruncMap = new HashMap<>();
targetToTargetReplacementMappingQueryBased.keySet().forEach(expr -> {
if (expr instanceof DateTrunc && expr.child(0) instanceof Slot) {
slotDateTruncMap.put((Slot) expr.child(0), (DateTrunc) expr);
if (expr instanceof DateTrunc) {
paramExpressionToDateTruncMap.put(expr.child(0), (DateTrunc) expr);
}
});
Expression queryExpr = expressionShuttledToRewrite.child(0);
Map<Expression, Literal> shuttledQueryParamToExpressionMap = new HashMap<>();
// TODO: 2024/12/5 optimize performance
for (Map.Entry<Expression, Literal> expressionEntry : shuttledQueryMap.entrySet()) {
Expression shuttledQueryParamExpression = ExpressionUtils.shuttleExpressionWithLineage(
expressionEntry.getKey(), sourcePlan, sourcePlanBitSet);
shuttledQueryParamToExpressionMap.put(shuttledQueryParamExpression.child(0) instanceof Literal
? shuttledQueryParamExpression.child(1) : shuttledQueryParamExpression.child(0),
expressionEntry.getValue());
}

if (slotDateTruncMap.isEmpty()) {
//mv date_trunc slot can not offer slot for query,
if (paramExpressionToDateTruncMap.isEmpty() || shuttledQueryMap.isEmpty()
|| !shuttledQueryMap.containsKey(expressionShuttledToRewrite)
|| !paramExpressionToDateTruncMap.containsKey(queryExpr)) {
// mv date_trunc expression can not offer expression for query,
// can not try to rewrite by date_trunc, bail out
return ImmutableList.of();
}
Expression replacedWithDateTrunc =
ExpressionUtils.replace(expressionShuttledToRewrite, slotDateTruncMap);
// check date_trunc(slot#0, 'day') > '2024-01-01' can simplify
replacedWithDateTrunc = new ExpressionOptimization().rewrite(replacedWithDateTrunc,

Map<Expression, Expression> datetruncMap = new HashMap<>();
Literal queryLiteral = shuttledQueryMap.get(expressionShuttledToRewrite);
datetruncMap.put(queryExpr, queryLiteral);
Expression replacedWithLiteral = ExpressionUtils.replace(
paramExpressionToDateTruncMap.get(queryExpr), datetruncMap);
Expression foldedExpressionWithLiteral = FoldConstantRuleOnFE.evaluate(replacedWithLiteral,
new ExpressionRewriteContext(cascadesContext));
if (replacedWithDateTrunc.equals(expressionShuttledToRewrite)) {
if (foldedExpressionWithLiteral.equals(queryLiteral)) {
// after date_trunc simplify if equals to original expression, could rewritten by mv
replacedExpression = ExpressionUtils.replace(expressionShuttledToRewrite,
targetToTargetReplacementMappingQueryBased,
slotDateTruncMap);
paramExpressionToDateTruncMap);
}
if (replacedExpression.anyMatch(slotsToRewrite::contains)) {
return ImmutableList.of();
Expand Down Expand Up @@ -794,7 +811,7 @@ protected SplitPredicate predicatesCompensate(
viewToQuerySlotMapping,
comparisonResult);
// range compensate
final Set<Expression> rangeCompensatePredicates = Predicates.compensateRangePredicate(
final Map<Expression, Literal> rangeCompensatePredicates = Predicates.compensateRangePredicate(
queryStructInfo,
viewStructInfo,
viewToQuerySlotMapping,
Expand All @@ -811,15 +828,17 @@ protected SplitPredicate predicatesCompensate(
return SplitPredicate.INVALID_INSTANCE;
}
if (equalCompensateConjunctions.stream().anyMatch(expr -> expr.containsType(AggregateFunction.class))
|| rangeCompensatePredicates.stream().anyMatch(expr -> expr.containsType(AggregateFunction.class))
|| rangeCompensatePredicates.keySet().stream()
.anyMatch(expr -> expr.containsType(AggregateFunction.class))
|| residualCompensatePredicates.stream().anyMatch(expr ->
expr.containsType(AggregateFunction.class))) {
return SplitPredicate.INVALID_INSTANCE;
}
return SplitPredicate.of(equalCompensateConjunctions.isEmpty() ? BooleanLiteral.TRUE
: ExpressionUtils.and(equalCompensateConjunctions),
rangeCompensatePredicates.isEmpty() ? BooleanLiteral.TRUE
: ExpressionUtils.and(rangeCompensatePredicates),
: ExpressionUtils.and(rangeCompensatePredicates.keySet()),
rangeCompensatePredicates.isEmpty() ? ImmutableMap.of() : rangeCompensatePredicates,
residualCompensatePredicates.isEmpty() ? BooleanLiteral.TRUE
: ExpressionUtils.and(residualCompensatePredicates));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;

import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;

import java.util.List;
Expand All @@ -51,7 +52,7 @@ protected Plan rewriteQueryByView(MatchMode matchMode,
materializationContext.getShuttledExprToScanExprMapping(),
targetToSourceMapping,
queryStructInfo.getTableBitSet(),
cascadesContext
ImmutableMap.of(), cascadesContext
);
// Can not rewrite, bail out
if (expressionsRewritten.isEmpty()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,24 @@
import org.apache.doris.nereids.rules.expression.ExpressionNormalization;
import org.apache.doris.nereids.rules.expression.ExpressionOptimization;
import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.GreaterThan;
import org.apache.doris.nereids.trees.expressions.LessThan;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.Utils;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Sets;

import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
Expand Down Expand Up @@ -139,7 +145,7 @@ public static Set<Expression> compensateEquivalence(StructInfo queryStructInfo,
/**
* compensate range predicates
*/
public static Set<Expression> compensateRangePredicate(StructInfo queryStructInfo,
public static Map<Expression, Literal> compensateRangePredicate(StructInfo queryStructInfo,
StructInfo viewStructInfo,
SlotMapping viewToQuerySlotMapping,
ComparisonResult comparisonResult,
Expand All @@ -159,16 +165,28 @@ public static Set<Expression> compensateRangePredicate(StructInfo queryStructInf
Sets.difference(viewRangeQueryBasedSet, queryRangeSet).copyInto(differentExpressions);
// the range predicate in query and view is same, don't need to compensate
if (differentExpressions.isEmpty()) {
return differentExpressions;
return ImmutableMap.of();
}
// try to normalize the different expressions
Set<Expression> normalizedExpressions =
normalizeExpression(ExpressionUtils.and(differentExpressions), cascadesContext);
if (!queryRangeSet.containsAll(normalizedExpressions)) {
// normalized expressions is not in query, can not compensate
// todo compensate if date_trunc
return null;
}
return normalizedExpressions;
Map<Expression, Literal> normalizedExpressionsWithLiteral = new HashMap<>();
for (Expression expression : normalizedExpressions) {
Set<Literal> literalSet = expression.collect(expressionTreeNode -> expressionTreeNode instanceof Literal);
if (!(expression instanceof ComparisonPredicate)
|| (expression instanceof GreaterThan || expression instanceof LessThan)
|| literalSet.size() != 1) {
normalizedExpressionsWithLiteral.put(expression, null);
continue;
}
normalizedExpressionsWithLiteral.put(expression, literalSet.iterator().next());
}
return normalizedExpressionsWithLiteral;
}

private static Set<Expression> normalizeExpression(Expression expression, CascadesContext cascadesContext) {
Expand Down Expand Up @@ -220,14 +238,19 @@ public String toString() {
*/
public static final class SplitPredicate {
public static final SplitPredicate INVALID_INSTANCE =
SplitPredicate.of(null, null, null);
SplitPredicate.of(null, null, null, null);
private final Optional<Expression> equalPredicate;
private final Optional<Expression> rangePredicate;
private final Optional<Map<Expression, Literal>> rangePredicateMap;
private final Optional<Expression> residualPredicate;

public SplitPredicate(Expression equalPredicate, Expression rangePredicate, Expression residualPredicate) {
public SplitPredicate(Expression equalPredicate,
Expression rangePredicate,
Map<Expression, Literal> rangePredicateMap,
Expression residualPredicate) {
this.equalPredicate = Optional.ofNullable(equalPredicate);
this.rangePredicate = Optional.ofNullable(rangePredicate);
this.rangePredicateMap = Optional.ofNullable(rangePredicateMap);
this.residualPredicate = Optional.ofNullable(residualPredicate);
}

Expand All @@ -239,6 +262,10 @@ public Expression getRangePredicate() {
return rangePredicate.orElse(BooleanLiteral.TRUE);
}

public Map<Expression, Literal> getRangePredicateMap() {
return rangePredicateMap.orElse(ImmutableMap.of());
}

public Expression getResidualPredicate() {
return residualPredicate.orElse(BooleanLiteral.TRUE);
}
Expand All @@ -248,8 +275,9 @@ public Expression getResidualPredicate() {
*/
public static SplitPredicate of(Expression equalPredicates,
Expression rangePredicates,
Map<Expression, Literal> rangePredicateSet,
Expression residualPredicates) {
return new SplitPredicate(equalPredicates, rangePredicates, residualPredicates);
return new SplitPredicate(equalPredicates, rangePredicates, rangePredicateSet, residualPredicates);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ public Predicates.SplitPredicate getSplitPredicate() {
return Predicates.SplitPredicate.of(
equalPredicates.isEmpty() ? null : ExpressionUtils.and(equalPredicates),
rangePredicates.isEmpty() ? null : ExpressionUtils.and(rangePredicates),
null,
residualPredicates.isEmpty() ? null : ExpressionUtils.and(residualPredicates));
}

Expand Down

0 comments on commit 2a3a9fa

Please sign in to comment.