Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

10 30 merge #17

Draft
wants to merge 3 commits into
base: pinterest-integration-3.3
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions be/src/exec/cross_join_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,10 @@ Status CrossJoinNode::init(const TPlanNode& tnode, RuntimeState* state) {
RETURN_IF_ERROR(
Expr::create_expr_trees(_pool, tnode.nestloop_join_node.join_conjuncts, &_join_conjuncts, state));
}

if (tnode.nestloop_join_node.__isset.interpolate_passthrough) {
_interpolate_passthrough = tnode.nestloop_join_node.interpolate_passthrough;
}
if (tnode.nestloop_join_node.__isset.sql_join_conjuncts) {
_sql_join_conjuncts = tnode.nestloop_join_node.sql_join_conjuncts;
}
Expand Down Expand Up @@ -619,6 +623,11 @@ std::vector<std::shared_ptr<pipeline::OperatorFactory>> CrossJoinNode::_decompos
left_ops.emplace_back(std::make_shared<LimitOperatorFactory>(context->next_operator_id(), id(), limit()));
}

if (_interpolate_passthrough && !context->is_colocate_group()) {
left_ops = context->maybe_interpolate_local_passthrough_exchange(runtime_state(), id(), left_ops,
context->degree_of_parallelism(), true);
}

if constexpr (std::is_same_v<BuildFactory, SpillableNLJoinBuildOperatorFactory>) {
may_add_chunk_accumulate_operator(left_ops, context, id());
}
Expand Down
1 change: 1 addition & 0 deletions be/src/exec/cross_join_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ class CrossJoinNode final : public ExecNode {
std::vector<uint32_t> _buf_selective;

std::vector<RuntimeFilterBuildDescriptor*> _build_runtime_filters;
bool _interpolate_passthrough = false;
};

} // namespace starrocks
Original file line number Diff line number Diff line change
Expand Up @@ -455,7 +455,8 @@ private Expr analyzeExpr(SelectAnalyzer.RewriteAliasVisitor visitor,
Expr newExpr = defineExpr.clone(smap);
newExpr = newExpr.accept(visitor, null);
newExpr = Expr.analyzeAndCastFold(newExpr);
if (!newExpr.getType().equals(type)) {
Type newType = newExpr.getType();
if (!type.isFullyCompatible(newType)) {
newExpr = new CastExpr(type, newExpr);
}
return newExpr;
Expand Down
2 changes: 1 addition & 1 deletion fe/fe-core/src/main/java/com/starrocks/catalog/Column.java
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ public Column(Column column) {
this.name = column.getName();
this.columnId = column.getColumnId();
this.type = column.type;
this.type.setAggStateDesc(this.aggStateDesc);
this.type.setAggStateDesc(column.aggStateDesc);
this.aggregationType = column.getAggregationType();
this.isAggregationTypeImplicit = column.isAggregationTypeImplicit();
this.isKey = column.isKey();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,19 @@ public AggStateCombinator(AggStateCombinator other) {

public static Optional<AggStateCombinator> of(AggregateFunction aggFunc) {
try {
Type intermediateType = aggFunc.getIntermediateTypeOrReturnType();
Type intermediateType = aggFunc.getIntermediateTypeOrReturnType().clone();
FunctionName funcName = new FunctionName(aggFunc.functionName() + FunctionSet.AGG_STATE_SUFFIX);
AggStateCombinator aggStateFunc = new AggStateCombinator(funcName, Arrays.asList(aggFunc.getArgs()),
intermediateType);
aggStateFunc.setBinaryType(TFunctionBinaryType.BUILTIN);
aggStateFunc.setPolymorphic(aggFunc.isPolymorphic());
aggStateFunc.setAggStateDesc(new AggStateDesc(aggFunc));

AggStateDesc aggStateDesc = new AggStateDesc(aggFunc);
aggStateFunc.setAggStateDesc(aggStateDesc);
// `agg_state` function's type will contain agg state desc.
intermediateType.setAggStateDesc(aggStateDesc);
// use agg state desc's nullable as `agg_state` function's nullable
aggStateFunc.setIsNullable(aggStateDesc.getResultNullable());
LOG.info("Register agg state function: {}", aggStateFunc.functionName());
return Optional.of(aggStateFunc);
} catch (Exception e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,15 @@ public static Optional<AggStateMergeCombinator> of(AggregateFunction aggFunc) {
new AggStateMergeCombinator(functionName, imtermediateType, aggFunc.getReturnType());
aggStateMergeFunc.setBinaryType(TFunctionBinaryType.BUILTIN);
aggStateMergeFunc.setPolymorphic(aggFunc.isPolymorphic());
aggStateMergeFunc.setAggStateDesc(new AggStateDesc(aggFunc));
AggStateDesc aggStateDesc;
if (aggFunc.getAggStateDesc() != null) {
aggStateDesc = aggFunc.getAggStateDesc().clone();
} else {
aggStateDesc = new AggStateDesc(aggFunc);
}
aggStateMergeFunc.setAggStateDesc(aggStateDesc);
// use agg state desc's nullable as `agg_state` function's nullable
aggStateMergeFunc.setIsNullable(aggStateDesc.getResultNullable());
LOG.info("Register agg state function: {}", aggStateMergeFunc.functionName());
return Optional.of(aggStateMergeFunc);
} catch (Exception e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,15 @@ public static Optional<AggStateUnionCombinator> of(AggregateFunction aggFunc) {
new AggStateUnionCombinator(functionName, intermediateType);
aggStateUnionFunc.setBinaryType(TFunctionBinaryType.BUILTIN);
aggStateUnionFunc.setPolymorphic(aggFunc.isPolymorphic());
aggStateUnionFunc.setAggStateDesc(new AggStateDesc(aggFunc));
AggStateDesc aggStateDesc;
if (aggFunc.getAggStateDesc() != null) {
aggStateDesc = aggFunc.getAggStateDesc().clone();
} else {
aggStateDesc = new AggStateDesc(aggFunc);
}
aggStateUnionFunc.setAggStateDesc(aggStateDesc);
// use agg state desc's nullable as `agg_state` function's nullable
aggStateUnionFunc.setIsNullable(aggStateDesc.getResultNullable());
LOG.info("Register agg state function: {}", aggStateUnionFunc.functionName());
return Optional.of(aggStateUnionFunc);
} catch (Exception e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ public static Function getAnalyzedCombinatorFunction(ConnectContext session,
return null;
}
AggregateFunction aggFunc = (AggregateFunction) argFn;
if (aggFunc.getNumArgs() == 1 && argumentTypes[0].isDecimalOfAnyVersion()) {
if (aggFunc.getNumArgs() == 1) {
// only copy argument if it's a decimal type
AggregateFunction argFnCopy = (AggregateFunction) aggFunc.copy();
argFnCopy.setArgsType(argumentTypes);
Expand Down Expand Up @@ -208,7 +208,9 @@ private static AggregateFunction getAggStateFunction(ConnectContext session,
if (!(fn instanceof AggregateFunction)) {
return null;
}
return (AggregateFunction) fn;
AggregateFunction result = (AggregateFunction) fn.copy();
result.setAggStateDesc(aggStateDesc);
return result;
}

private static Type[] getNewArgumentTypes(Type[] origArgTypes, String argFnName, Type arg0Type) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,11 @@ protected void toThrift(TPlanNode msg) {
String sqlJoinPredicate = otherJoinConjuncts.stream().map(Expr::toSql).collect(Collectors.joining(","));
msg.nestloop_join_node.setSql_join_conjuncts(sqlJoinPredicate);
}
SessionVariable sv = ConnectContext.get().getSessionVariable();
if (getCanLocalShuffle()) {
msg.nestloop_join_node.setInterpolate_passthrough(sv.isHashJoinInterpolatePassthrough());
}


if (!buildRuntimeFilters.isEmpty()) {
msg.nestloop_join_node.setBuild_runtime_filters(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
import com.starrocks.analysis.CaseWhenClause;
import com.starrocks.analysis.Expr;
import com.starrocks.analysis.FunctionCallExpr;
import com.starrocks.analysis.FunctionParams;
import com.starrocks.analysis.IntLiteral;
import com.starrocks.analysis.IsNullPredicate;
import com.starrocks.analysis.OrderByElement;
Expand All @@ -58,6 +59,8 @@
import com.starrocks.catalog.Table;
import com.starrocks.catalog.Type;
import com.starrocks.catalog.View;
import com.starrocks.catalog.combinator.AggStateDesc;
import com.starrocks.catalog.combinator.AggStateUnionCombinator;
import com.starrocks.common.ErrorCode;
import com.starrocks.common.ErrorReport;
import com.starrocks.common.FeConstants;
Expand All @@ -68,6 +71,7 @@
import com.starrocks.sql.analyzer.AnalyzerUtils;
import com.starrocks.sql.analyzer.ExpressionAnalyzer;
import com.starrocks.sql.analyzer.Field;
import com.starrocks.sql.analyzer.FunctionAnalyzer;
import com.starrocks.sql.analyzer.RelationFields;
import com.starrocks.sql.analyzer.RelationId;
import com.starrocks.sql.analyzer.Scope;
Expand Down Expand Up @@ -250,7 +254,7 @@ public Map<String, Expr> parseDefineExprWithoutAnalyze(String originalSql) {
case FunctionSet.HLL_UNION:
case FunctionSet.PERCENTILE_UNION:
case FunctionSet.COUNT: {
MVColumnItem item = buildAggColumnItem(selectListItem, slots);
MVColumnItem item = buildAggColumnItem(new ConnectContext(), selectListItem, slots);
expr = item.getDefineExpr();
name = item.getName();
break;
Expand Down Expand Up @@ -337,7 +341,7 @@ public void analyze(ConnectContext context) {
if (!(selectRelation.getRelation() instanceof TableRelation)) {
throw new UnsupportedMVException("Materialized view query statement only support direct query from table.");
}
int beginIndexOfAggregation = genColumnAndSetIntoStmt(table, selectRelation);
int beginIndexOfAggregation = genColumnAndSetIntoStmt(context, table, selectRelation);
if (selectRelation.isDistinct() || selectRelation.hasAggregation()) {
setMvKeysType(KeysType.AGG_KEYS);
}
Expand Down Expand Up @@ -409,7 +413,7 @@ private void analyzeExprWithTableAlias(ConnectContext context,
.collect(Collectors.toList()))), context);
}

private int genColumnAndSetIntoStmt(Table table, SelectRelation selectRelation) {
private int genColumnAndSetIntoStmt(ConnectContext context, Table table, SelectRelation selectRelation) {
List<MVColumnItem> mvColumnItemList = Lists.newArrayList();

boolean meetAggregate = false;
Expand Down Expand Up @@ -442,30 +446,33 @@ private int genColumnAndSetIntoStmt(Table table, SelectRelation selectRelation)
&& ((FunctionCallExpr) selectListItemExpr).isAggregateFunction()) {
// Aggregate Function must match pattern.
FunctionCallExpr functionCallExpr = (FunctionCallExpr) selectListItemExpr;
String functionName = functionCallExpr.getFnName().getFunction();

MVColumnPattern mvColumnPattern =
CreateMaterializedViewStmt.FN_NAME_TO_PATTERN.get(functionName.toLowerCase());
if (mvColumnPattern == null) {
throw new UnsupportedMVException(
"Materialized view does not support function:%s, supported functions are: %s",
functionCallExpr.toSqlImpl(), FN_NAME_TO_PATTERN.keySet());
}
String functionName = functionCallExpr.getFnName().getFunction().toLowerCase();
// current version not support count(distinct) function in creating materialized view
if (!isReplay && functionCallExpr.isDistinct()) {
throw new UnsupportedMVException(
"Materialized view does not support distinct function " + functionCallExpr.toSqlImpl());
}
if (!mvColumnPattern.match(functionCallExpr)) {
throw new UnsupportedMVException(
"The function " + functionName + " must match pattern:" + mvColumnPattern);
if (!FN_NAME_TO_PATTERN.containsKey(functionName)) {
// eg: avg_union(avg_state(xxx))
} else {
MVColumnPattern mvColumnPattern = FN_NAME_TO_PATTERN.get(functionName);
if (mvColumnPattern == null) {

throw new UnsupportedMVException(
"Materialized view does not support function:%s, supported functions are: %s",
functionCallExpr.toSqlImpl(), FN_NAME_TO_PATTERN.keySet());
}
if (!mvColumnPattern.match(functionCallExpr)) {
throw new UnsupportedMVException(
"The function " + functionName + " must match pattern:" + mvColumnPattern);
}
}
if (beginIndexOfAggregation == -1) {
beginIndexOfAggregation = i;
}
meetAggregate = true;

mvColumnItem = buildAggColumnItem(selectListItem, slots);
mvColumnItem = buildAggColumnItem(context, selectListItem, slots);
if (!mvColumnNameSet.add(mvColumnItem.getName())) {
ErrorReport.reportSemanticException(ErrorCode.ERR_DUP_FIELDNAME, mvColumnItem.getName());
}
Expand Down Expand Up @@ -527,17 +534,68 @@ private MVColumnItem buildNonAggColumnItem(SelectListItem selectListItem,
type = AnalyzerUtils.transformTableColumnType(type, false);
}
Set<String> baseColumnNames = baseSlotRefs.stream().map(slot -> slot.getColumnName())
.collect(Collectors.toSet());
return new MVColumnItem(columnName, type, null, false, defineExpr,
.collect(Collectors.toSet());
return new MVColumnItem(columnName, type, null, null, false, defineExpr,
defineExpr.isNullable(), baseColumnNames);
}

// Convert the aggregate function to MVColumn.
private MVColumnItem buildAggColumnItem(SelectListItem selectListItem,
private MVColumnItem buildAggColumnItem(ConnectContext context,
SelectListItem selectListItem,
List<SlotRef> baseSlotRefs) {
FunctionCallExpr node = (FunctionCallExpr) selectListItem.getExpr();
String functionName = node.getFnName().getFunction();
Preconditions.checkState(node.getChildren().size() == 1, "Aggregate function only support one child");

if (!FN_NAME_TO_PATTERN.containsKey(functionName)) {
if (Strings.isNullOrEmpty(selectListItem.getAlias())) {
throw new SemanticException("Create materialized view non-slot ref expression should have an alias:" +
selectListItem.getExpr());
}

Expr defineExpr = node.getChild(0);
List<Type> argTypes = node.getChildren().stream().map(Expr::getType).collect(Collectors.toList());
Type arg0Type = argTypes.get(0);
if (arg0Type.getAggStateDesc() == null) {
throw new UnsupportedMVException("Unsupported function:" + functionName + ", cannot find agg state desc from " +
"arg0");
}
FunctionParams params = new FunctionParams(false, Lists.newArrayList());
Type[] argumentTypes = argTypes.toArray(Type[]::new);
Boolean[] isArgumentConstants = argTypes.stream().map(x -> false).toArray(Boolean[]::new);
Function function = FunctionAnalyzer.getAnalyzedAggregateFunction(context, functionName,
params, argumentTypes, isArgumentConstants, NodePosition.ZERO);
if (function == null || !(function instanceof AggStateUnionCombinator)) {
throw new UnsupportedMVException("Unsupported function:" + functionName);
}
AggStateUnionCombinator aggFunction = (AggStateUnionCombinator) function;
String mvColumnName = MVUtils.getMVColumnName(selectListItem.getAlias());
AggStateDesc aggStateDesc = aggFunction.getAggStateDesc();
Type type = aggFunction.getIntermediateTypeOrReturnType();
if (type.isWildcardDecimal()) {
throw new UnsupportedMVException("Unsupported wildcard decimal type in materialized view:" + type + ", " +
"function:" + node);
}
if (aggStateDesc.getArgTypes().stream().anyMatch(t -> t.isWildcardDecimal())) {
throw new UnsupportedMVException("Unsupported wildcard decimal type in materialized view:" + type + ", " +
"function:" + node);
}
Set<String> baseColumnNames = baseSlotRefs.stream().map(slot -> slot.getColumnName())
.collect(Collectors.toSet());
AggregateType mvAggregateType = AggregateType.AGG_STATE_UNION;
Type finalType = AnalyzerUtils.transformTableColumnType(type, false);
return new MVColumnItem(mvColumnName, finalType, mvAggregateType, aggStateDesc, false,
defineExpr, aggStateDesc.getResultNullable(), baseColumnNames);
} else {
return buildAggColumnItemWithPattern(selectListItem, baseSlotRefs);
}
}

// Convert the aggregate function to MVColumn.
private MVColumnItem buildAggColumnItemWithPattern(SelectListItem selectListItem,
List<SlotRef> baseSlotRefs) {
FunctionCallExpr functionCallExpr = (FunctionCallExpr) selectListItem.getExpr();
String functionName = functionCallExpr.getFnName().getFunction();
Preconditions.checkState(functionCallExpr.getChildren().size() == 1, "Aggregate function only support one child");
Expr defineExpr = functionCallExpr.getChild(0);
AggregateType mvAggregateType = null;
Type baseType = defineExpr.getType();
Expand Down Expand Up @@ -640,8 +698,8 @@ private MVColumnItem buildAggColumnItem(SelectListItem selectListItem,
String.format("Invalid aggregate function '%s' for '%s'", mvAggregateType, type));
}
Set<String> baseColumnNames = baseSlotRefs.stream().map(slot -> slot.getColumnName())
.collect(Collectors.toSet());
return new MVColumnItem(mvColumnName, type, mvAggregateType, false,
.collect(Collectors.toSet());
return new MVColumnItem(mvColumnName, type, mvAggregateType, null, false,
defineExpr, functionCallExpr.isNullable(), baseColumnNames);
}

Expand Down
Loading