From 538d319c1d79f51a2b35402fb696b4d64e51ee1a Mon Sep 17 00:00:00 2001
From: tobe
Date: Thu, 24 Feb 2022 19:25:09 +0800
Subject: [PATCH 1/2] Support window union with unsaferow opt

---
 .../openmldb/batch/nodes/WindowAggPlan.scala  | 103 +++++++++++++++++-
 1 file changed, 99 insertions(+), 4 deletions(-)

diff --git a/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/nodes/WindowAggPlan.scala b/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/nodes/WindowAggPlan.scala
index 689981f8fdd..8e2d476f6c7 100755
--- a/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/nodes/WindowAggPlan.scala
+++ b/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/nodes/WindowAggPlan.scala
@@ -95,11 +95,20 @@ object WindowAggPlan {
       val internalRowRdd = repartitionDf.queryExecution.toRdd
       val zippedRdd = rowRdd.zip(internalRowRdd)
 
-      val outputInternalRowRdd = zippedRdd.mapPartitionsWithIndex {
-        case (partitionIndex, iter) =>
-          val computer = WindowAggPlanUtil.createComputer(partitionIndex, hadoopConf, sparkFeConfig, windowAggConfig)
-          unsafeWindowAggIter(computer, iter, sparkFeConfig, windowAggConfig, outputSchema)
+      val outputInternalRowRdd = if (isWindowWithUnion) {
+        zippedRdd.mapPartitionsWithIndex {
+          case (partitionIndex, iter) =>
+            val computer = WindowAggPlanUtil.createComputer(partitionIndex, hadoopConf, sparkFeConfig, windowAggConfig)
+            unsafeWindowAggIterWithUnionFlag(computer, iter, sparkFeConfig, windowAggConfig, outputSchema)
+        }
+      } else {
+        zippedRdd.mapPartitionsWithIndex {
+          case (partitionIndex, iter) =>
+            val computer = WindowAggPlanUtil.createComputer(partitionIndex, hadoopConf, sparkFeConfig, windowAggConfig)
+            unsafeWindowAggIter(computer, iter, sparkFeConfig, windowAggConfig, outputSchema)
+        }
       }
+
       SparkUtil.rddInternalRowToDf(ctx.getSparkSession, outputInternalRowRdd, outputSchema)
     } else {
       // isUnsafeRowOptimization is false
@@ -430,6 +439,92 @@ object WindowAggPlan {
     }
   }
 
+  def unsafeWindowAggIterWithUnionFlag(computer: WindowComputer,
+                                       inputIter: Iterator[(Row, InternalRow)],
+                                       sqlConfig: OpenmldbBatchConfig,
+                                       config: WindowAggConfig,
+                                       outputSchema: StructType): Iterator[InternalRow] = {
+    val flagIdx = config.unionFlagIdx
+    var lastRow: Row = null
+
+    // Take the iterator if the limit has been set
+    val limitInputIter = if (config.limitCnt > 0) inputIter.take(config.limitCnt) else inputIter
+
+    if (config.partIdIdx != 0) {
+      val skewGroups = config.groupIdxs :+ config.partIdIdx
+      computer.resetGroupKeyComparator(skewGroups)
+    }
+
+    val resIter = if (sqlConfig.enableWindowSkewOpt) {
+      limitInputIter.flatMap(zippedRow => {
+
+        val row = zippedRow._1
+        val internalRow = zippedRow._2
+
+        if (lastRow != null) {
+          computer.checkPartition(row, lastRow)
+        }
+        lastRow = row
+
+        val orderKey = computer.extractKey(row)
+        val expandedFlag = row.getBoolean(config.expandedFlagIdx)
+        if (!isValidOrder(orderKey)) {
+          None
+        } else if (!expandedFlag) {
+          Some(computer.unsafeCompute(internalRow, orderKey, config.keepIndexColumn, config.unionFlagIdx, outputSchema))
+        } else {
+          computer.bufferRowOnly(row, orderKey)
+          None
+        }
+      })
+    } else {
+      limitInputIter.flatMap(zippedRow => {
+
+        val row = zippedRow._1
+        val internalRow = zippedRow._2
+
+        if (lastRow != null) {
+          computer.checkPartition(row, lastRow)
+        }
+        lastRow = row
+
+        val orderKey = computer.extractKey(row)
+        if (isValidOrder(orderKey)) {
+
+          val unionFlag = row.getBoolean(flagIdx)
+          if (unionFlag) {
+            if (sqlConfig.enableWindowSkewOpt) {
+              val expandedFlag = row.getBoolean(config.expandedFlagIdx)
+              if (!expandedFlag) {
+                Some(computer.unsafeCompute(internalRow, orderKey, config.keepIndexColumn, config.unionFlagIdx,
+                  outputSchema))
+              } else {
+                if (!config.instanceNotInWindow) {
+                  computer.bufferRowOnly(row, orderKey)
+                }
+                None
+              }
+            } else {
+              Some(computer.unsafeCompute(internalRow, orderKey, config.keepIndexColumn, config.unionFlagIdx,
+                outputSchema))
+            }
+          } else {
+            // secondary
+            computer.bufferRowOnly(row, orderKey)
+            None
+          }
+        } else {
+          None
+        }
+      })
+    }
+    AutoDestructibleIterator(resIter) {
+      computer.delete()
+    }
+  }
+
+
+
   def isValidOrder(key: Long): Boolean = {
     // TODO: Ignore the null value, maybe handle null in the future
     if (key == null) {

From 8f92e06a51059928d055f784ea8ba805289ad30f Mon Sep 17 00:00:00 2001
From: tobe
Date: Thu, 24 Feb 2022 19:25:56 +0800
Subject: [PATCH 2/2] Add unit test for window union with unsaferow opt

---
 .../unsafe/TestUnsafeWindowWithUnion.scala    | 59 +++++++++++++++++++
 1 file changed, 59 insertions(+)
 create mode 100644 java/openmldb-batch/src/test/scala/com/_4paradigm/openmldb/batch/end2end/unsafe/TestUnsafeWindowWithUnion.scala

diff --git a/java/openmldb-batch/src/test/scala/com/_4paradigm/openmldb/batch/end2end/unsafe/TestUnsafeWindowWithUnion.scala b/java/openmldb-batch/src/test/scala/com/_4paradigm/openmldb/batch/end2end/unsafe/TestUnsafeWindowWithUnion.scala
new file mode 100644
index 00000000000..730de7d9316
--- /dev/null
+++ b/java/openmldb-batch/src/test/scala/com/_4paradigm/openmldb/batch/end2end/unsafe/TestUnsafeWindowWithUnion.scala
@@ -0,0 +1,59 @@
+/*
+ * Copyright 2021 4Paradigm
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com._4paradigm.openmldb.batch.end2end.unsafe
+
+import com._4paradigm.openmldb.batch.SparkTestSuite
+import com._4paradigm.openmldb.batch.api.OpenmldbSession
+import com._4paradigm.openmldb.batch.end2end.DataUtil
+import com._4paradigm.openmldb.batch.utils.SparkUtil
+
+class TestUnsafeWindowWithUnion extends SparkTestSuite {
+
+  override def customizedBefore(): Unit = {
+    val spark = getSparkSession
+    spark.conf.set("spark.openmldb.unsaferow.opt", true)
+  }
+
+  test("Test unsafe window") {
+    val spark = getSparkSession
+    val sess = new OpenmldbSession(spark)
+
+    val df = DataUtil.getTestDf(spark)
+    sess.registerTable("t1", df)
+    df.createOrReplaceTempView("t1")
+
+    val sqlText ="""
+      | SELECT sum(trans_amount) OVER w AS w_sum_amount FROM t1
+      | WINDOW w AS (
+      | UNION t1
+      | PARTITION BY name
+      | ORDER BY trans_time
+      | ROWS BETWEEN 10 PRECEDING AND CURRENT ROW);
+      """.stripMargin
+
+    val outputDf = sess.sql(sqlText)
+    val count = outputDf.count()
+    val expectedCount = df.count()
+    assert(count == expectedCount)
+  }
+
+  override def customizedAfter(): Unit = {
+    val spark = getSparkSession
+    spark.conf.set("spark.openmldb.unsaferow.opt", false)
+  }
+
+}
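
Usage note (appended for illustration, not part of either patch): the sketch below shows how the new code path would be exercised from standalone user code rather than from the test harness. The config key spark.openmldb.unsaferow.opt, the OpenmldbSession API calls, and the column names (name, trans_amount, trans_time) are taken from the test above; the object name, the local[*] master, the literal rows, and their types are illustrative assumptions only (the unit test builds its input with DataUtil.getTestDf instead).

    import com._4paradigm.openmldb.batch.api.OpenmldbSession
    import org.apache.spark.sql.SparkSession

    object WindowUnionUnsafeRowSketch {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[*]").getOrCreate()
        // Turn on the UnsafeRow optimization before creating the OpenMLDB session.
        spark.conf.set("spark.openmldb.unsaferow.opt", true)
        val sess = new OpenmldbSession(spark)

        // Illustrative rows; any table with these columns would do.
        val df = spark.createDataFrame(Seq(
          ("tom", 100, 1L),
          ("tom", 200, 2L),
          ("amy", 300, 3L)
        )).toDF("name", "trans_amount", "trans_time")
        sess.registerTable("t1", df)
        df.createOrReplaceTempView("t1")

        // The window clause unions t1 with itself; with the opt flag on, this query
        // is planned through the unsafeWindowAggIterWithUnionFlag path added in PATCH 1/2.
        val outputDf = sess.sql(
          """
            | SELECT sum(trans_amount) OVER w AS w_sum_amount FROM t1
            | WINDOW w AS (
            | UNION t1
            | PARTITION BY name
            | ORDER BY trans_time
            | ROWS BETWEEN 10 PRECEDING AND CURRENT ROW);
          """.stripMargin)
        println(outputDf.count())
      }
    }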