Skip to content

Commit

Permalink
[Feature] (Part3) Support creating materialized views with common agg…
Browse files Browse the repository at this point in the history
…regate state functions (StarRocks#51510)

Signed-off-by: shuming.li <[email protected]>
  • Loading branch information
LiShuMing committed Oct 29, 2024
1 parent 025829e commit 1ae593a
Show file tree
Hide file tree
Showing 28 changed files with 1,582 additions and 196 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -455,7 +455,8 @@ private Expr analyzeExpr(SelectAnalyzer.RewriteAliasVisitor visitor,
Expr newExpr = defineExpr.clone(smap);
newExpr = newExpr.accept(visitor, null);
newExpr = Expr.analyzeAndCastFold(newExpr);
if (!newExpr.getType().equals(type)) {
Type newType = newExpr.getType();
if (!type.isFullyCompatible(newType)) {
newExpr = new CastExpr(type, newExpr);
}
return newExpr;
Expand Down
2 changes: 1 addition & 1 deletion fe/fe-core/src/main/java/com/starrocks/catalog/Column.java
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ public Column(Column column) {
this.name = column.getName();
this.columnId = column.getColumnId();
this.type = column.type;
this.type.setAggStateDesc(this.aggStateDesc);
this.type.setAggStateDesc(column.aggStateDesc);
this.aggregationType = column.getAggregationType();
this.isAggregationTypeImplicit = column.isAggregationTypeImplicit();
this.isKey = column.isKey();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,19 @@ public AggStateCombinator(AggStateCombinator other) {

public static Optional<AggStateCombinator> of(AggregateFunction aggFunc) {
try {
Type intermediateType = aggFunc.getIntermediateTypeOrReturnType();
Type intermediateType = aggFunc.getIntermediateTypeOrReturnType().clone();
FunctionName funcName = new FunctionName(aggFunc.functionName() + FunctionSet.AGG_STATE_SUFFIX);
AggStateCombinator aggStateFunc = new AggStateCombinator(funcName, Arrays.asList(aggFunc.getArgs()),
intermediateType);
aggStateFunc.setBinaryType(TFunctionBinaryType.BUILTIN);
aggStateFunc.setPolymorphic(aggFunc.isPolymorphic());
aggStateFunc.setAggStateDesc(new AggStateDesc(aggFunc));

AggStateDesc aggStateDesc = new AggStateDesc(aggFunc);
aggStateFunc.setAggStateDesc(aggStateDesc);
// `agg_state` function's type will contain agg state desc.
intermediateType.setAggStateDesc(aggStateDesc);
// use agg state desc's nullable as `agg_state` function's nullable
aggStateFunc.setIsNullable(aggStateDesc.getResultNullable());
LOG.info("Register agg state function: {}", aggStateFunc.functionName());
return Optional.of(aggStateFunc);
} catch (Exception e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,15 @@ public static Optional<AggStateMergeCombinator> of(AggregateFunction aggFunc) {
new AggStateMergeCombinator(functionName, imtermediateType, aggFunc.getReturnType());
aggStateMergeFunc.setBinaryType(TFunctionBinaryType.BUILTIN);
aggStateMergeFunc.setPolymorphic(aggFunc.isPolymorphic());
aggStateMergeFunc.setAggStateDesc(new AggStateDesc(aggFunc));
AggStateDesc aggStateDesc;
if (aggFunc.getAggStateDesc() != null) {
aggStateDesc = aggFunc.getAggStateDesc().clone();
} else {
aggStateDesc = new AggStateDesc(aggFunc);
}
aggStateMergeFunc.setAggStateDesc(aggStateDesc);
// use agg state desc's nullable as `agg_state` function's nullable
aggStateMergeFunc.setIsNullable(aggStateDesc.getResultNullable());
LOG.info("Register agg state function: {}", aggStateMergeFunc.functionName());
return Optional.of(aggStateMergeFunc);
} catch (Exception e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,15 @@ public static Optional<AggStateUnionCombinator> of(AggregateFunction aggFunc) {
new AggStateUnionCombinator(functionName, intermediateType);
aggStateUnionFunc.setBinaryType(TFunctionBinaryType.BUILTIN);
aggStateUnionFunc.setPolymorphic(aggFunc.isPolymorphic());
aggStateUnionFunc.setAggStateDesc(new AggStateDesc(aggFunc));
AggStateDesc aggStateDesc;
if (aggFunc.getAggStateDesc() != null) {
aggStateDesc = aggFunc.getAggStateDesc().clone();
} else {
aggStateDesc = new AggStateDesc(aggFunc);
}
aggStateUnionFunc.setAggStateDesc(aggStateDesc);
// use agg state desc's nullable as `agg_state` function's nullable
aggStateUnionFunc.setIsNullable(aggStateDesc.getResultNullable());
LOG.info("Register agg state function: {}", aggStateUnionFunc.functionName());
return Optional.of(aggStateUnionFunc);
} catch (Exception e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ public static Function getAnalyzedCombinatorFunction(ConnectContext session,
return null;
}
AggregateFunction aggFunc = (AggregateFunction) argFn;
if (aggFunc.getNumArgs() == 1 && argumentTypes[0].isDecimalOfAnyVersion()) {
if (aggFunc.getNumArgs() == 1) {
// only copy argument if it's a decimal type
AggregateFunction argFnCopy = (AggregateFunction) aggFunc.copy();
argFnCopy.setArgsType(argumentTypes);
Expand Down Expand Up @@ -208,7 +208,9 @@ private static AggregateFunction getAggStateFunction(ConnectContext session,
if (!(fn instanceof AggregateFunction)) {
return null;
}
return (AggregateFunction) fn;
AggregateFunction result = (AggregateFunction) fn.copy();
result.setAggStateDesc(aggStateDesc);
return result;
}

private static Type[] getNewArgumentTypes(Type[] origArgTypes, String argFnName, Type arg0Type) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
import com.starrocks.analysis.CaseWhenClause;
import com.starrocks.analysis.Expr;
import com.starrocks.analysis.FunctionCallExpr;
import com.starrocks.analysis.FunctionParams;
import com.starrocks.analysis.IntLiteral;
import com.starrocks.analysis.IsNullPredicate;
import com.starrocks.analysis.OrderByElement;
Expand All @@ -58,6 +59,8 @@
import com.starrocks.catalog.Table;
import com.starrocks.catalog.Type;
import com.starrocks.catalog.View;
import com.starrocks.catalog.combinator.AggStateDesc;
import com.starrocks.catalog.combinator.AggStateUnionCombinator;
import com.starrocks.common.ErrorCode;
import com.starrocks.common.ErrorReport;
import com.starrocks.common.FeConstants;
Expand All @@ -68,6 +71,7 @@
import com.starrocks.sql.analyzer.AnalyzerUtils;
import com.starrocks.sql.analyzer.ExpressionAnalyzer;
import com.starrocks.sql.analyzer.Field;
import com.starrocks.sql.analyzer.FunctionAnalyzer;
import com.starrocks.sql.analyzer.RelationFields;
import com.starrocks.sql.analyzer.RelationId;
import com.starrocks.sql.analyzer.Scope;
Expand Down Expand Up @@ -250,7 +254,7 @@ public Map<String, Expr> parseDefineExprWithoutAnalyze(String originalSql) {
case FunctionSet.HLL_UNION:
case FunctionSet.PERCENTILE_UNION:
case FunctionSet.COUNT: {
MVColumnItem item = buildAggColumnItem(selectListItem, slots);
MVColumnItem item = buildAggColumnItem(new ConnectContext(), selectListItem, slots);
expr = item.getDefineExpr();
name = item.getName();
break;
Expand Down Expand Up @@ -337,7 +341,7 @@ public void analyze(ConnectContext context) {
if (!(selectRelation.getRelation() instanceof TableRelation)) {
throw new UnsupportedMVException("Materialized view query statement only support direct query from table.");
}
int beginIndexOfAggregation = genColumnAndSetIntoStmt(table, selectRelation);
int beginIndexOfAggregation = genColumnAndSetIntoStmt(context, table, selectRelation);
if (selectRelation.isDistinct() || selectRelation.hasAggregation()) {
setMvKeysType(KeysType.AGG_KEYS);
}
Expand Down Expand Up @@ -409,7 +413,7 @@ private void analyzeExprWithTableAlias(ConnectContext context,
.collect(Collectors.toList()))), context);
}

private int genColumnAndSetIntoStmt(Table table, SelectRelation selectRelation) {
private int genColumnAndSetIntoStmt(ConnectContext context, Table table, SelectRelation selectRelation) {
List<MVColumnItem> mvColumnItemList = Lists.newArrayList();

boolean meetAggregate = false;
Expand Down Expand Up @@ -442,30 +446,33 @@ private int genColumnAndSetIntoStmt(Table table, SelectRelation selectRelation)
&& ((FunctionCallExpr) selectListItemExpr).isAggregateFunction()) {
// Aggregate Function must match pattern.
FunctionCallExpr functionCallExpr = (FunctionCallExpr) selectListItemExpr;
String functionName = functionCallExpr.getFnName().getFunction();

MVColumnPattern mvColumnPattern =
CreateMaterializedViewStmt.FN_NAME_TO_PATTERN.get(functionName.toLowerCase());
if (mvColumnPattern == null) {
throw new UnsupportedMVException(
"Materialized view does not support function:%s, supported functions are: %s",
functionCallExpr.toSqlImpl(), FN_NAME_TO_PATTERN.keySet());
}
String functionName = functionCallExpr.getFnName().getFunction().toLowerCase();
// current version not support count(distinct) function in creating materialized view
if (!isReplay && functionCallExpr.isDistinct()) {
throw new UnsupportedMVException(
"Materialized view does not support distinct function " + functionCallExpr.toSqlImpl());
}
if (!mvColumnPattern.match(functionCallExpr)) {
throw new UnsupportedMVException(
"The function " + functionName + " must match pattern:" + mvColumnPattern);
if (!FN_NAME_TO_PATTERN.containsKey(functionName)) {
// eg: avg_union(avg_state(xxx))
} else {
MVColumnPattern mvColumnPattern = FN_NAME_TO_PATTERN.get(functionName);
if (mvColumnPattern == null) {

throw new UnsupportedMVException(
"Materialized view does not support function:%s, supported functions are: %s",
functionCallExpr.toSqlImpl(), FN_NAME_TO_PATTERN.keySet());
}
if (!mvColumnPattern.match(functionCallExpr)) {
throw new UnsupportedMVException(
"The function " + functionName + " must match pattern:" + mvColumnPattern);
}
}
if (beginIndexOfAggregation == -1) {
beginIndexOfAggregation = i;
}
meetAggregate = true;

mvColumnItem = buildAggColumnItem(selectListItem, slots);
mvColumnItem = buildAggColumnItem(context, selectListItem, slots);
if (!mvColumnNameSet.add(mvColumnItem.getName())) {
ErrorReport.reportSemanticException(ErrorCode.ERR_DUP_FIELDNAME, mvColumnItem.getName());
}
Expand Down Expand Up @@ -527,17 +534,68 @@ private MVColumnItem buildNonAggColumnItem(SelectListItem selectListItem,
type = AnalyzerUtils.transformTableColumnType(type, false);
}
Set<String> baseColumnNames = baseSlotRefs.stream().map(slot -> slot.getColumnName())
.collect(Collectors.toSet());
return new MVColumnItem(columnName, type, null, false, defineExpr,
.collect(Collectors.toSet());
return new MVColumnItem(columnName, type, null, null, false, defineExpr,
defineExpr.isNullable(), baseColumnNames);
}

// Convert the aggregate function to MVColumn.
private MVColumnItem buildAggColumnItem(SelectListItem selectListItem,
private MVColumnItem buildAggColumnItem(ConnectContext context,
SelectListItem selectListItem,
List<SlotRef> baseSlotRefs) {
FunctionCallExpr node = (FunctionCallExpr) selectListItem.getExpr();
String functionName = node.getFnName().getFunction();
Preconditions.checkState(node.getChildren().size() == 1, "Aggregate function only support one child");

if (!FN_NAME_TO_PATTERN.containsKey(functionName)) {
if (Strings.isNullOrEmpty(selectListItem.getAlias())) {
throw new SemanticException("Create materialized view non-slot ref expression should have an alias:" +
selectListItem.getExpr());
}

Expr defineExpr = node.getChild(0);
List<Type> argTypes = node.getChildren().stream().map(Expr::getType).collect(Collectors.toList());
Type arg0Type = argTypes.get(0);
if (arg0Type.getAggStateDesc() == null) {
throw new UnsupportedMVException("Unsupported function:" + functionName + ", cannot find agg state desc from " +
"arg0");
}
FunctionParams params = new FunctionParams(false, Lists.newArrayList());
Type[] argumentTypes = argTypes.toArray(Type[]::new);
Boolean[] isArgumentConstants = argTypes.stream().map(x -> false).toArray(Boolean[]::new);
Function function = FunctionAnalyzer.getAnalyzedAggregateFunction(context, functionName,
params, argumentTypes, isArgumentConstants, NodePosition.ZERO);
if (function == null || !(function instanceof AggStateUnionCombinator)) {
throw new UnsupportedMVException("Unsupported function:" + functionName);
}
AggStateUnionCombinator aggFunction = (AggStateUnionCombinator) function;
String mvColumnName = MVUtils.getMVColumnName(selectListItem.getAlias());
AggStateDesc aggStateDesc = aggFunction.getAggStateDesc();
Type type = aggFunction.getIntermediateTypeOrReturnType();
if (type.isWildcardDecimal()) {
throw new UnsupportedMVException("Unsupported wildcard decimal type in materialized view:" + type + ", " +
"function:" + node);
}
if (aggStateDesc.getArgTypes().stream().anyMatch(t -> t.isWildcardDecimal())) {
throw new UnsupportedMVException("Unsupported wildcard decimal type in materialized view:" + type + ", " +
"function:" + node);
}
Set<String> baseColumnNames = baseSlotRefs.stream().map(slot -> slot.getColumnName())
.collect(Collectors.toSet());
AggregateType mvAggregateType = AggregateType.AGG_STATE_UNION;
Type finalType = AnalyzerUtils.transformTableColumnType(type, false);
return new MVColumnItem(mvColumnName, finalType, mvAggregateType, aggStateDesc, false,
defineExpr, aggStateDesc.getResultNullable(), baseColumnNames);
} else {
return buildAggColumnItemWithPattern(selectListItem, baseSlotRefs);
}
}

// Convert the aggregate function to MVColumn.
private MVColumnItem buildAggColumnItemWithPattern(SelectListItem selectListItem,
List<SlotRef> baseSlotRefs) {
FunctionCallExpr functionCallExpr = (FunctionCallExpr) selectListItem.getExpr();
String functionName = functionCallExpr.getFnName().getFunction();
Preconditions.checkState(functionCallExpr.getChildren().size() == 1, "Aggregate function only support one child");
Expr defineExpr = functionCallExpr.getChild(0);
AggregateType mvAggregateType = null;
Type baseType = defineExpr.getType();
Expand Down Expand Up @@ -640,8 +698,8 @@ private MVColumnItem buildAggColumnItem(SelectListItem selectListItem,
String.format("Invalid aggregate function '%s' for '%s'", mvAggregateType, type));
}
Set<String> baseColumnNames = baseSlotRefs.stream().map(slot -> slot.getColumnName())
.collect(Collectors.toSet());
return new MVColumnItem(mvColumnName, type, mvAggregateType, false,
.collect(Collectors.toSet());
return new MVColumnItem(mvColumnName, type, mvAggregateType, null, false,
defineExpr, functionCallExpr.isNullable(), baseColumnNames);
}

Expand Down
12 changes: 9 additions & 3 deletions fe/fe-core/src/main/java/com/starrocks/sql/ast/MVColumnItem.java
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,12 @@
import com.starrocks.catalog.Column;
import com.starrocks.catalog.OlapTable;
import com.starrocks.catalog.Type;
import com.starrocks.catalog.combinator.AggStateDesc;

import java.util.Set;

import static com.starrocks.catalog.Column.COLUMN_UNIQUE_ID_INIT_VALUE;

/**
* This is a result of semantic analysis for AddMaterializedViewClause.
* It is used to construct real mv column in MaterializedViewHandler.
Expand All @@ -54,16 +57,19 @@ public class MVColumnItem {
private Type type;
private boolean isKey;
private AggregateType aggregationType;
private AggStateDesc aggStateDesc;
private boolean isAllowNull;
private boolean isAggregationTypeImplicit;
private Expr defineExpr;
private Set<String> baseColumnNames;

public MVColumnItem(String name, Type type, AggregateType aggregateType, boolean isAggregationTypeImplicit,
public MVColumnItem(String name, Type type, AggregateType aggregateType, AggStateDesc aggStateDesc,
boolean isAggregationTypeImplicit,
Expr defineExpr, boolean isAllowNull, Set<String> baseColumnNames) {
this.name = name;
this.type = type;
this.aggregationType = aggregateType;
this.aggStateDesc = aggStateDesc;
this.isAggregationTypeImplicit = isAggregationTypeImplicit;
this.defineExpr = defineExpr;
this.isAllowNull = isAllowNull;
Expand Down Expand Up @@ -124,8 +130,8 @@ public Column toMVColumn(OlapTable olapTable) {
Column result;
boolean hasUniqueId = olapTable.getMaxColUniqueId() >= 0;
if (baseColumn == null) {
result = new Column(name, type, isKey, aggregationType, isAllowNull,
null, "");
result = new Column(name, type, isKey, aggregationType, aggStateDesc, isAllowNull,
null, "", COLUMN_UNIQUE_ID_INIT_VALUE);
if (defineExpr != null) {
result.setDefineExpr(defineExpr);
}
Expand Down
Loading

0 comments on commit 1ae593a

Please sign in to comment.