[FLINK-34702][table-planner] Refactor Deduplicate optimization to defer to StreamPhysicalRank for valid StreamExecDeduplicate node conversion to avoid exceptions #25380

Merged · 7 commits · Sep 25, 2024
@@ -34,7 +34,6 @@
import org.apache.flink.table.planner.plan.nodes.physical.stream.StreamPhysicalChangelogNormalize;
import org.apache.flink.table.planner.plan.nodes.physical.stream.StreamPhysicalCorrelateBase;
import org.apache.flink.table.planner.plan.nodes.physical.stream.StreamPhysicalDataStreamScan;
import org.apache.flink.table.planner.plan.nodes.physical.stream.StreamPhysicalDeduplicate;
import org.apache.flink.table.planner.plan.nodes.physical.stream.StreamPhysicalDropUpdateBefore;
import org.apache.flink.table.planner.plan.nodes.physical.stream.StreamPhysicalExchange;
import org.apache.flink.table.planner.plan.nodes.physical.stream.StreamPhysicalExpand;
@@ -184,8 +183,6 @@ public StreamPhysicalRel visit(
return visitOverAggregate((StreamPhysicalOverAggregateBase) rel, requireDeterminism);
} else if (rel instanceof StreamPhysicalRank) {
return visitRank((StreamPhysicalRank) rel, requireDeterminism);
} else if (rel instanceof StreamPhysicalDeduplicate) {
return visitDeduplicate((StreamPhysicalDeduplicate) rel, requireDeterminism);
} else if (rel instanceof StreamPhysicalWindowDeduplicate) {
return visitWindowDeduplicate(
(StreamPhysicalWindowDeduplicate) rel, requireDeterminism);
@@ -677,22 +674,6 @@ private StreamPhysicalRel visitRank(
}
}

private StreamPhysicalRel visitDeduplicate(
final StreamPhysicalDeduplicate dedup, final ImmutableBitSet requireDeterminism) {
// output row type same as input and does not change output columns' order
if (inputInsertOnly(dedup)) {
// similar to rank, output is deterministic when input is insert only, so required
// determinism always be satisfied here.
return transmitDeterminismRequirement(dedup, NO_REQUIRED_DETERMINISM);
} else {
// Deduplicate always has unique key currently(exec node has null check and inner
// state only support data with keys), so only pass the left columns of required
// determinism to input.
return transmitDeterminismRequirement(
dedup, requireDeterminism.except(ImmutableBitSet.of(dedup.getUniqueKeys())));
}
}

private StreamPhysicalRel visitWindowDeduplicate(
final StreamPhysicalWindowDeduplicate winDedup,
final ImmutableBitSet requireDeterminism) {
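For reference, the removed visitDeduplicate handled two cases: with an insert-only input the deduplicate output is deterministic, so no requirement is pushed down; otherwise the deduplicate's unique keys are subtracted from the required-determinism set before it is forwarded to the input, because the dedup state pins those columns. A minimal Scala sketch of that bitset arithmetic, with illustrative column indices (the real logic lives in the Java visitor this hunk edits, and per the PR title is now reached through the visitRank path):

```scala
import org.apache.calcite.util.ImmutableBitSet

object DeterminismHandOffSketch extends App {
  // Illustrative indices: determinism is required on columns 0..2, and the
  // deduplicate's unique key is column 0 (its partition column).
  val requireDeterminism = ImmutableBitSet.of(0, 1, 2)
  val dedupUniqueKeys = ImmutableBitSet.of(0)

  // What the removed branch forwarded to the input: everything the dedup
  // state itself does not already make deterministic.
  val pushedToInput = requireDeterminism.except(dedupUniqueKeys)
  println(pushedToInput) // {1, 2}
}
```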
TemporalJoinRewriteWithUniqueKeyRule.java
@@ -20,10 +20,10 @@

import org.apache.flink.table.api.ValidationException;
import org.apache.flink.table.planner.plan.metadata.FlinkRelMetadataQuery;
import org.apache.flink.table.planner.plan.nodes.exec.stream.StreamExecDeduplicate;
import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalJoin;
import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalRel;
import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalSnapshot;
import org.apache.flink.table.planner.plan.nodes.physical.stream.StreamPhysicalDeduplicate;
import org.apache.flink.table.planner.plan.nodes.physical.stream.StreamPhysicalRank;
import org.apache.flink.table.planner.plan.rules.common.CommonTemporalTableJoinRule;
import org.apache.flink.table.planner.plan.utils.TemporalJoinUtil;
@@ -57,8 +57,8 @@
* a table source or a view only if it contains the unique key and time attribute.
*
* <p>Flink supports extracting the primary key and row time attribute from the view if the view comes
* from {@link StreamPhysicalRank} node which can convert to a {@link StreamPhysicalDeduplicate}
* node.
* from a {@link StreamPhysicalRank} node, which can finally be converted to a
* {@link StreamExecDeduplicate} node.
*/
@Value.Enclosing
public class TemporalJoinRewriteWithUniqueKeyRule
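As a concrete, hypothetical illustration of the view shape the updated Javadoc describes (table and column names invented here): the classic last-row deduplication query exposes a unique key (currency) plus a row time attribute (update_time), which is exactly what the temporal join build side needs.

```scala
// Hypothetical deduplication view: the planner turns this ROW_NUMBER pattern
// into a StreamPhysicalRank, which finally becomes a StreamExecDeduplicate.
val versionedRatesViewSql =
  """SELECT currency, rate, update_time
    |FROM (
    |  SELECT *,
    |    ROW_NUMBER() OVER (PARTITION BY currency ORDER BY update_time DESC) AS rn
    |  FROM rates_history
    |) WHERE rn = 1
    |""".stripMargin
```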
FlinkRelMdColumnUniqueness.scala
@@ -251,21 +251,26 @@ class FlinkRelMdColumnUniqueness private extends MetadataHandler[BuiltInMetadata
mq: RelMetadataQuery,
columns: ImmutableBitSet,
ignoreNulls: Boolean): JBoolean = {
val input = rank.getInput
val rankFunColumnIndex = RankUtil.getRankNumberColumnIndex(rank).getOrElse(-1)
if (rankFunColumnIndex < 0) {
mq.areColumnsUnique(input, columns, ignoreNulls)
if (RankUtil.isDeduplication(rank)) {
columns != null && util.Arrays.equals(columns.toArray, rank.partitionKey.toArray)
} else {
val childColumns = columns.clear(rankFunColumnIndex)
val isChildColumnsUnique = mq.areColumnsUnique(input, childColumns, ignoreNulls)
if (isChildColumnsUnique != null && isChildColumnsUnique) {
true
val input = rank.getInput

val rankFunColumnIndex = RankUtil.getRankNumberColumnIndex(rank).getOrElse(-1)
if (rankFunColumnIndex < 0) {
mq.areColumnsUnique(input, columns, ignoreNulls)
} else {
rank.rankType match {
case RankType.ROW_NUMBER =>
val fields = columns.toArray
(rank.partitionKey.toArray :+ rankFunColumnIndex).forall(fields.contains(_))
case _ => false
val childColumns = columns.clear(rankFunColumnIndex)
val isChildColumnsUnique = mq.areColumnsUnique(input, childColumns, ignoreNulls)
if (isChildColumnsUnique != null && isChildColumnsUnique) {
true
} else {
rank.rankType match {
case RankType.ROW_NUMBER =>
val fields = columns.toArray
(rank.partitionKey.toArray :+ rankFunColumnIndex).forall(fields.contains(_))
case _ => false
}
}
}
}
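The new deduplication branch is deliberately narrow: it reports uniqueness only when the queried column set is exactly the partition key. A self-contained sketch of that predicate, with made-up column indices (the helper name is illustrative):

```scala
import org.apache.calcite.util.ImmutableBitSet
import java.util

object DedupUniquenessSketch extends App {
  // Stand-in for the branch above: true only for exactly the partition-key set.
  def dedupColumnsUnique(columns: ImmutableBitSet, partitionKey: ImmutableBitSet): Boolean =
    columns != null && util.Arrays.equals(columns.toArray, partitionKey.toArray)

  val partitionKey = ImmutableBitSet.of(0) // e.g. PARTITION BY the first column
  println(dedupColumnsUnique(ImmutableBitSet.of(0), partitionKey))    // true: the key itself
  println(dedupColumnsUnique(ImmutableBitSet.of(0, 1), partitionKey)) // false: supersets not claimed here
}
```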
@@ -277,14 +282,6 @@ class FlinkRelMdColumnUniqueness private extends MetadataHandler[BuiltInMetadata
columns: ImmutableBitSet,
ignoreNulls: Boolean): JBoolean = mq.areColumnsUnique(rel.getInput, columns, ignoreNulls)

def areColumnsUnique(
rel: StreamPhysicalDeduplicate,
mq: RelMetadataQuery,
columns: ImmutableBitSet,
ignoreNulls: Boolean): JBoolean = {
columns != null && util.Arrays.equals(columns.toArray, rel.getUniqueKeys)
}

def areColumnsUnique(
rel: StreamPhysicalChangelogNormalize,
mq: RelMetadataQuery,
FlinkRelMdModifiedMonotonicity.scala
@@ -29,6 +29,7 @@ import org.apache.flink.table.planner.plan.nodes.physical.batch.{BatchPhysicalCo
import org.apache.flink.table.planner.plan.nodes.physical.stream._
import org.apache.flink.table.planner.plan.schema.{FlinkPreparingTableBase, IntermediateRelTable, TableSourceTable}
import org.apache.flink.table.planner.plan.stats.{WithLower, WithUpper}
import org.apache.flink.table.planner.plan.utils.RankUtil
import org.apache.flink.types.RowKind

import org.apache.calcite.plan.hep.HepRelVertex
@@ -186,70 +187,78 @@ class FlinkRelMdModifiedMonotonicity private extends MetadataHandler[ModifiedMon
}

def getRelModifiedMonotonicity(rel: Rank, mq: RelMetadataQuery): RelModifiedMonotonicity = {
val fmq = FlinkRelMetadataQuery.reuseOrCreate(mq)
val inputMonotonicity = fmq.getRelModifiedMonotonicity(rel.getInput)
rel match {
case physicalRank: StreamPhysicalRank if RankUtil.isDeduplication(rel) =>
getPhysicalRankModifiedMonotonicity(physicalRank, mq)

// If child monotonicity is null, we should return early.
if (inputMonotonicity == null) {
return null
}
case _ =>
val fmq = FlinkRelMetadataQuery.reuseOrCreate(mq)
val inputMonotonicity = fmq.getRelModifiedMonotonicity(rel.getInput)

// if partitionBy a update field or partitionBy a field whose mono is null, just return null
if (rel.partitionKey.exists(e => inputMonotonicity.fieldMonotonicities(e) != CONSTANT)) {
return null
}
// If child monotonicity is null, we should return early.
if (inputMonotonicity == null) {
return null
}

val fieldCount = rel.getRowType.getFieldCount
// if partitioning on an update field or on a field whose monotonicity is null, just return null
if (rel.partitionKey.exists(e => inputMonotonicity.fieldMonotonicities(e) != CONSTANT)) {
return null
}

// init current mono
val currentMonotonicity = notMonotonic(fieldCount)
// 1. partitionBy field is CONSTANT
rel.partitionKey.foreach(e => currentMonotonicity.fieldMonotonicities(e) = CONSTANT)
// 2. row number filed is CONSTANT
if (rel.outputRankNumber) {
currentMonotonicity.fieldMonotonicities(fieldCount - 1) = CONSTANT
}
// 3. time attribute field is increasing
(0 until fieldCount).foreach(
e => {
if (FlinkTypeFactory.isTimeIndicatorType(rel.getRowType.getFieldList.get(e).getType)) {
inputMonotonicity.fieldMonotonicities(e) = INCREASING
val fieldCount = rel.getRowType.getFieldCount

// init current mono
val currentMonotonicity = notMonotonic(fieldCount)
// 1. partitionBy field is CONSTANT
rel.partitionKey.foreach(e => currentMonotonicity.fieldMonotonicities(e) = CONSTANT)
// 2. row number field is CONSTANT
if (rel.outputRankNumber) {
currentMonotonicity.fieldMonotonicities(fieldCount - 1) = CONSTANT
}
// 3. time attribute field is increasing
(0 until fieldCount).foreach(
e => {
if (FlinkTypeFactory.isTimeIndicatorType(rel.getRowType.getFieldList.get(e).getType)) {
inputMonotonicity.fieldMonotonicities(e) = INCREASING
}
})
val fieldCollations = rel.orderKey.getFieldCollations
if (fieldCollations.nonEmpty) {
// 4. process the first collation field, we can only deduce the first collation field
val firstCollation = fieldCollations.get(0)
// Collation field index in child node will be same with Rank node,
// see ProjectToLogicalProjectAndWindowRule for details.
val fieldMonotonicity =
inputMonotonicity.fieldMonotonicities(firstCollation.getFieldIndex)
val result = fieldMonotonicity match {
case SqlMonotonicity.INCREASING | SqlMonotonicity.CONSTANT
if firstCollation.direction == RelFieldCollation.Direction.DESCENDING =>
INCREASING
case SqlMonotonicity.DECREASING | SqlMonotonicity.CONSTANT
if firstCollation.direction == RelFieldCollation.Direction.ASCENDING =>
DECREASING
case _ => NOT_MONOTONIC
}
currentMonotonicity.fieldMonotonicities(firstCollation.getFieldIndex) = result
}
})
val fieldCollations = rel.orderKey.getFieldCollations
if (fieldCollations.nonEmpty) {
// 4. process the first collation field, we can only deduce the first collation field
val firstCollation = fieldCollations.get(0)
// Collation field index in child node will be same with Rank node,
// see ProjectToLogicalProjectAndWindowRule for details.
val fieldMonotonicity = inputMonotonicity.fieldMonotonicities(firstCollation.getFieldIndex)
val result = fieldMonotonicity match {
case SqlMonotonicity.INCREASING | SqlMonotonicity.CONSTANT
if firstCollation.direction == RelFieldCollation.Direction.DESCENDING =>
INCREASING
case SqlMonotonicity.DECREASING | SqlMonotonicity.CONSTANT
if firstCollation.direction == RelFieldCollation.Direction.ASCENDING =>
DECREASING
case _ => NOT_MONOTONIC
}
currentMonotonicity.fieldMonotonicities(firstCollation.getFieldIndex) = result
}

currentMonotonicity
currentMonotonicity
}
}

def getRelModifiedMonotonicity(
rel: StreamPhysicalDeduplicate,
private def getPhysicalRankModifiedMonotonicity(
rank: StreamPhysicalRank,
mq: RelMetadataQuery): RelModifiedMonotonicity = {
if (allAppend(mq, rel.getInput)) {
if (rel.keepLastRow || rel.isRowtime) {
// Can't use RankUtil.canConvertToDeduplicate directly because modifyKindSetTrait is undefined.
if (allAppend(mq, rank.getInput)) {
if (RankUtil.keepLastDeduplicateRow(rank.orderKey) || rank.sortOnRowTime) {
val mono = new RelModifiedMonotonicity(
Array.fill(rel.getRowType.getFieldCount)(NOT_MONOTONIC))
rel.getUniqueKeys.foreach(e => mono.fieldMonotonicities(e) = CONSTANT)
Array.fill(rank.getRowType.getFieldCount)(NOT_MONOTONIC))
rank.partitionKey.toArray.foreach(e => mono.fieldMonotonicities(e) = CONSTANT)
mono
} else {
// FirstRow does not generate updates.
new RelModifiedMonotonicity(Array.fill(rel.getRowType.getFieldCount)(CONSTANT))
new RelModifiedMonotonicity(Array.fill(rank.getRowType.getFieldCount)(CONSTANT))
}
} else {
null
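In other words, a simplified sketch of the branch above, assuming keepLastDeduplicateRow distinguishes last-row from first-row semantics: for an append-only input, last-row deduplication may replace a previously emitted row per key, so only the partition-key fields stay constant; first-row deduplication never updates what it emitted. Plain arrays stand in for RelModifiedMonotonicity below, and all names are illustrative:

```scala
import org.apache.calcite.sql.validate.SqlMonotonicity
import org.apache.calcite.sql.validate.SqlMonotonicity.{CONSTANT, NOT_MONOTONIC}

object DedupMonotonicitySketch extends App {
  def dedupMonotonicity(
      fieldCount: Int,
      partitionKeys: Array[Int],
      keepLastRow: Boolean): Array[SqlMonotonicity] =
    if (keepLastRow) {
      // Last row per key may be replaced: only the partition key is constant.
      val mono = Array.fill[SqlMonotonicity](fieldCount)(NOT_MONOTONIC)
      partitionKeys.foreach(i => mono(i) = CONSTANT)
      mono
    } else {
      // First row per key is emitted once and never updated.
      Array.fill[SqlMonotonicity](fieldCount)(CONSTANT)
    }

  println(dedupMonotonicity(3, Array(0), keepLastRow = true).mkString(", "))
  // CONSTANT, NOT_MONOTONIC, NOT_MONOTONIC
}
```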
FlinkRelMdUniqueKeys.scala
@@ -27,7 +27,7 @@ import org.apache.flink.table.planner.plan.nodes.physical.stream._
import org.apache.flink.table.planner.plan.schema.{FlinkPreparingTableBase, TableSourceTable}
import org.apache.flink.table.planner.plan.utils.{FlinkRelMdUtil, RankUtil}
import org.apache.flink.table.runtime.groupwindow.NamedWindowProperty
import org.apache.flink.table.runtime.operators.rank.{ConstantRankRange, RankType}
import org.apache.flink.table.runtime.operators.rank.RankType
import org.apache.flink.table.types.logical.utils.LogicalTypeCasts

import com.google.common.collect.ImmutableSet
@@ -290,19 +290,9 @@ class FlinkRelMdUniqueKeys private extends MetadataHandler[BuiltInMetadata.Uniqu

def getRankUniqueKeys(rel: Rank, inputKeys: JSet[ImmutableBitSet]): JSet[ImmutableBitSet] = {
val rankFunColumnIndex = RankUtil.getRankNumberColumnIndex(rel).getOrElse(-1)
// for Rank node that can convert to Deduplicate, unique key is partition key
val canConvertToDeduplicate: Boolean = {
val rankRange = rel.rankRange
val isRowNumberType = rel.rankType == RankType.ROW_NUMBER
val isLimit1 = rankRange match {
case rankRange: ConstantRankRange =>
rankRange.getRankStart == 1 && rankRange.getRankEnd == 1
case _ => false
}
isRowNumberType && isLimit1
}

if (canConvertToDeduplicate) {
if (RankUtil.isDeduplication(rel)) {
// for Rank node that can convert to Deduplicate, unique key is partition key
val retSet = new JHashSet[ImmutableBitSet]
retSet.add(rel.partitionKey)
retSet
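The deleted canConvertToDeduplicate block spells out what a "Rank that is really a deduplication" means, and RankUtil.isDeduplication presumably encapsulates the same test. A sketch reconstructed from the removed lines (the actual RankUtil signature may differ):

```scala
import org.apache.flink.table.runtime.operators.rank.{ConstantRankRange, RankRange, RankType}

// Reconstructed from the inlined predicate this hunk removes: a rank is a
// deduplication iff it is ROW_NUMBER limited to exactly the first row.
def isDeduplicationSketch(rankType: RankType, rankRange: RankRange): Boolean = {
  val isRowNumberType = rankType == RankType.ROW_NUMBER
  val isLimit1 = rankRange match {
    case r: ConstantRankRange => r.getRankStart == 1 && r.getRankEnd == 1
    case _ => false
  }
  isRowNumberType && isLimit1
}
```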
@@ -325,13 +315,6 @@ def getUniqueKeys(rel: Sort, mq: RelMetadataQuery, ignoreNulls: Boolean): JSet[ImmutableBitSu
def getUniqueKeys(rel: Sort, mq: RelMetadataQuery, ignoreNulls: Boolean): JSet[ImmutableBitSet] =
mq.getUniqueKeys(rel.getInput, ignoreNulls)

def getUniqueKeys(
rel: StreamPhysicalDeduplicate,
mq: RelMetadataQuery,
ignoreNulls: Boolean): JSet[ImmutableBitSet] = {
ImmutableSet.of(ImmutableBitSet.of(rel.getUniqueKeys.map(Integer.valueOf).toList))
}

def getUniqueKeys(
rel: StreamPhysicalChangelogNormalize,
mq: RelMetadataQuery,
FlinkRelMdUpsertKeys.scala
@@ -24,7 +24,7 @@ import org.apache.flink.table.planner.plan.nodes.physical.batch.{BatchPhysicalGr
import org.apache.flink.table.planner.plan.nodes.physical.common.CommonPhysicalLookupJoin
import org.apache.flink.table.planner.plan.nodes.physical.stream._
import org.apache.flink.table.planner.plan.schema.IntermediateRelTable
import org.apache.flink.table.planner.plan.utils.FlinkRexUtil
import org.apache.flink.table.planner.plan.utils.{FlinkRexUtil, RankUtil}

import com.google.common.collect.ImmutableSet
import org.apache.calcite.plan.hep.HepRelVertex
@@ -91,23 +91,24 @@ class FlinkRelMdUpsertKeys private extends MetadataHandler[UpsertKeys] {
}

def getUpsertKeys(rel: Rank, mq: RelMetadataQuery): JSet[ImmutableBitSet] = {
val inputKeys = filterKeys(
FlinkRelMetadataQuery
.reuseOrCreate(mq)
.getUpsertKeys(rel.getInput),
rel.partitionKey)
FlinkRelMdUniqueKeys.INSTANCE.getRankUniqueKeys(rel, inputKeys)
rel match {
case rank: StreamPhysicalRank if RankUtil.isDeduplication(rel) =>
ImmutableSet.of(ImmutableBitSet.of(rank.partitionKey.toArray.map(Integer.valueOf).toList))
case _ =>
val inputKeys = filterKeys(
FlinkRelMetadataQuery
.reuseOrCreate(mq)
.getUpsertKeys(rel.getInput),
rel.partitionKey)
FlinkRelMdUniqueKeys.INSTANCE.getRankUniqueKeys(rel, inputKeys)
}
}
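So for a deduplicate-style rank, the upsert key is the PARTITION BY column set itself, regardless of which upsert keys the input exposes. A tiny illustration of the returned shape, with an invented column index:

```scala
import com.google.common.collect.ImmutableSet
import org.apache.calcite.util.ImmutableBitSet

object DedupUpsertKeySketch extends App {
  // E.g. last-row dedup over PARTITION BY user_id at column index 0: each key
  // has at most one current row, so {0} is the single upsert key.
  val partitionKey = ImmutableBitSet.of(0)
  val upsertKeys = ImmutableSet.of(partitionKey)
  println(upsertKeys) // [{0}]
}
```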

def getUpsertKeys(rel: Sort, mq: RelMetadataQuery): JSet[ImmutableBitSet] =
filterKeys(
FlinkRelMetadataQuery.reuseOrCreate(mq).getUpsertKeys(rel.getInput),
ImmutableBitSet.of(rel.getCollation.getKeys))

def getUpsertKeys(rel: StreamPhysicalDeduplicate, mq: RelMetadataQuery): JSet[ImmutableBitSet] = {
ImmutableSet.of(ImmutableBitSet.of(rel.getUniqueKeys.map(Integer.valueOf).toList))
}

def getUpsertKeys(
rel: StreamPhysicalChangelogNormalize,
mq: RelMetadataQuery): JSet[ImmutableBitSet] = {