From 2e7fb76aca6a41a96c43b46b1be2194281c2166f Mon Sep 17 00:00:00 2001
From: tobe
Date: Thu, 24 Feb 2022 17:39:17 +0800
Subject: [PATCH] feat: fix unsaferow for window (#1298)

* Support config of debugShowNodeDf for openmldb-batch

* Add rowToString in spark row util and add unit tests

* Update offset for string and non-string columns in row codec

* Add unit tests for unsafe row opt

* Add data util for openmldb-batch unit tests
---
 Makefile                                      |  1 +
 hybridse/src/codec/fe_row_codec.cc            |  5 ++
 .../openmldb/batch/OpenmldbBatchConfig.scala  |  7 ++-
 .../openmldb/batch/SparkPlanner.scala         |  5 +-
 .../openmldb/batch/nodes/WindowAggPlan.scala  |  6 --
 .../openmldb/batch/utils/SparkRowUtil.scala   | 53 ++++++++++++++---
 .../src/test/resources/log4j.properties       | 50 ++--------------
 .../openmldb/batch/end2end/DataUtil.scala     | 51 +++++++++++++++++
 .../end2end/unsafe/TestUnsafeProject.scala    | 52 +++++++++++++++++
 .../end2end/unsafe/TestUnsafeWindow.scala     | 57 +++++++++++++++++++
 .../batch/utils/TestSparkRowUtil.scala        | 16 +++++-
 11 files changed, 239 insertions(+), 64 deletions(-)
 create mode 100644 java/openmldb-batch/src/test/scala/com/_4paradigm/openmldb/batch/end2end/DataUtil.scala
 create mode 100644 java/openmldb-batch/src/test/scala/com/_4paradigm/openmldb/batch/end2end/unsafe/TestUnsafeProject.scala
 create mode 100644 java/openmldb-batch/src/test/scala/com/_4paradigm/openmldb/batch/end2end/unsafe/TestUnsafeWindow.scala

diff --git a/Makefile b/Makefile
index 48bcbeb04e5..42a2d41e740 100644
--- a/Makefile
+++ b/Makefile
@@ -118,6 +118,7 @@ configure: thirdparty-fast
 
 openmldb-clean:
 	rm -rf "$(OPENMLDB_BUILD_DIR)"
+	@cd java && ./mvnw clean
 
 THIRD_PARTY_BUILD_DIR ?= $(MAKEFILE_DIR)/.deps
 THIRD_PARTY_SRC_DIR ?= $(MAKEFILE_DIR)/thirdsrc
diff --git a/hybridse/src/codec/fe_row_codec.cc b/hybridse/src/codec/fe_row_codec.cc
index deabe95d7f4..aa56f8bf97b 100644
--- a/hybridse/src/codec/fe_row_codec.cc
+++ b/hybridse/src/codec/fe_row_codec.cc
@@ -893,6 +893,11 @@ RowFormat::RowFormat(const hybridse::codec::Schema* schema)
             next_str_pos_.insert(
                 std::make_pair(string_field_cnt, string_field_cnt));
             string_field_cnt += 1;
+
+            if (FLAGS_enable_spark_unsaferow_format) {
+                // For UnsafeRowOpt, advance the offset for string columns too: every column occupies 8 bytes in the fixed-length region
+                offset += 8;
+            }
         } else {
             auto TYPE_SIZE_MAP = codec::GetTypeSizeMap();
             auto it = TYPE_SIZE_MAP.find(column.type());
diff --git a/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/OpenmldbBatchConfig.scala b/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/OpenmldbBatchConfig.scala
index 2a94834787e..c34ab330103 100755
--- a/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/OpenmldbBatchConfig.scala
+++ b/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/OpenmldbBatchConfig.scala
@@ -29,7 +29,7 @@ class OpenmldbBatchConfig extends Serializable {
   @ConfigOption(name = "spark.sql.session.timeZone")
   var timeZone = "Asia/Shanghai"
 
-  // test mode 用于测试的时候验证相关问题
+  // test mode
   @ConfigOption(name = "openmldb.test.tiny", doc = "Control the number of rows to read; read the full table by default")
   var tinyData: Long = -1
 
@@ -42,6 +42,10 @@ class OpenmldbBatchConfig extends Serializable {
   @ConfigOption(name = "openmldb.test.print", doc = "Allow printing data during execution")
   var print: Boolean = false
 
+  // Debug options
+  @ConfigOption(name = "openmldb.debug.show_node_df", doc = "Use Spark DataFrame.show() for each physical node")
+  var debugShowNodeDf: Boolean = false
+
   // Window skew optimization
   @ConfigOption(name = "openmldb.window.skew.opt", doc = "Enable window skew optimization or not")
  var enableWindowSkewOpt: Boolean = false
@@ -66,7 +70,6 @@ class OpenmldbBatchConfig extends Serializable {
   @ConfigOption(name = "openmldb.window.skew.opt.config", doc = "The skew config for window skew optimization")
   var windowSkewOptConfig: String = ""
 
-  // 慢速执行模式
   @ConfigOption(name = "openmldb.slowRunCacheDir", doc =
     """
       | Slow run mode cache directory path. If specified, run OpenMLDB plan with slow mode.
diff --git a/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/SparkPlanner.scala b/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/SparkPlanner.scala
index a2d8b5dc69b..2af2a5efefa 100755
--- a/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/SparkPlanner.scala
+++ b/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/SparkPlanner.scala
@@ -273,10 +273,13 @@ class SparkPlanner(session: SparkSession, config: OpenmldbBatchConfig, sparkAppN
     // Set the output to context cache
     ctx.putPlanResult(root.GetNodeId(), outputSpatkInstance)
 
+    if (config.debugShowNodeDf) {
+      logger.warn(s"Debug and print DataFrame of nodeId: ${root.GetNodeId()}, nodeType: ${root.GetTypeName()}")
+      outputSpatkInstance.getDf().show()
+    }
     outputSpatkInstance
   }
 
-
   /**
    * Run plan slowly by storing and loading each intermediate result from external data path.
    */
diff --git a/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/nodes/WindowAggPlan.scala b/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/nodes/WindowAggPlan.scala
index 2c7141d97db..689981f8fdd 100755
--- a/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/nodes/WindowAggPlan.scala
+++ b/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/nodes/WindowAggPlan.scala
@@ -279,9 +279,6 @@ object WindowAggPlan {
       val skewGroups = config.groupIdxs :+ config.partIdIdx
       computer.resetGroupKeyComparator(skewGroups)
     }
-    if (sqlConfig.print) {
-      logger.info(s"windowAggIter mode: ${sqlConfig.enableWindowSkewOpt}")
-    }
 
     val resIter = if (sqlConfig.enableWindowSkewOpt) {
       limitInputIter.flatMap(zippedRow => {
@@ -341,9 +338,6 @@ object WindowAggPlan {
       val skewGroups = config.groupIdxs :+ config.partIdIdx
       computer.resetGroupKeyComparator(skewGroups)
     }
-    if (sqlConfig.print) {
-      logger.info(s"windowAggIter mode: ${sqlConfig.enableWindowSkewOpt}")
-    }
 
     val resIter = if (sqlConfig.enableWindowSkewOpt) {
       limitInputIter.flatMap(row => {
diff --git a/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/utils/SparkRowUtil.scala b/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/utils/SparkRowUtil.scala
index b99aecc3926..b7803693fd2 100644
--- a/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/utils/SparkRowUtil.scala
+++ b/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/utils/SparkRowUtil.scala
@@ -16,28 +16,65 @@
 
 package com._4paradigm.openmldb.batch.utils
 
-import com._4paradigm.hybridse.sdk.{HybridSeException, UnsupportedHybridSeException}
+import com._4paradigm.hybridse.sdk.HybridSeException
 import com._4paradigm.openmldb.proto.Type
 import org.apache.spark.sql.Row
-import org.apache.spark.sql.types.{
-  BooleanType, DataType, DateType, DoubleType, FloatType, IntegerType, LongType,
-  ShortType, StringType, TimestampType
-}
+import org.apache.spark.sql.types.{BooleanType, DataType, DateType, DoubleType, FloatType, IntegerType, LongType,
+  ShortType, StringType, StructType, TimestampType}
 
 object SparkRowUtil {
 
-  def getLongFromIndex(keyIdx: Int, sparkType: DataType, row: Row): java.lang.Long = {
+  def rowToString(schema: StructType, row: Row): String = {
+    val rowStr = new StringBuilder("Spark row: ")
+    (0 until schema.size).foreach(i => {
+      if (i == 0) {
+        rowStr ++= s"${schema(i).dataType}: ${getColumnStringValue(schema, row, i)}"
+      } else {
+        rowStr ++= s", ${schema(i).dataType}: ${getColumnStringValue(schema, row, i)}"
+      }
+    })
+    rowStr.toString()
+  }
+
+  /**
+   * Get the string value of the specified column.
+   *
+   * @param row the row to read
+   * @param index the index of the column
+   * @return the string representation of the column value
+   */
+  def getColumnStringValue(schema: StructType, row: Row, index: Int): String = {
+    if (row.isNullAt(index)) {
+      "null"
+    } else {
+      val colType = schema(index).dataType
+      colType match {
+        case BooleanType => String.valueOf(row.getBoolean(index))
+        case ShortType => String.valueOf(row.getShort(index))
+        case DoubleType => String.valueOf(row.getDouble(index))
+        case IntegerType => String.valueOf(row.getInt(index))
+        case LongType => String.valueOf(row.getLong(index))
+        case TimestampType => String.valueOf(row.getTimestamp(index))
+        case DateType => String.valueOf(row.getDate(index))
+        case StringType => row.getString(index)
+        case _ =>
+          throw new HybridSeException(s"Unsupported data type: $colType")
+      }
+    }
+  }
+
+  def getLongFromIndex(keyIdx: Int, colType: DataType, row: Row): java.lang.Long = {
     if (row.isNullAt(keyIdx)) {
       null
     } else {
-      sparkType match {
+      colType match {
         case ShortType => row.getShort(keyIdx).toLong
         case IntegerType => row.getInt(keyIdx).toLong
         case LongType => row.getLong(keyIdx)
         case TimestampType => row.getTimestamp(keyIdx).getTime
         case DateType => row.getDate(keyIdx).getTime
         case _ =>
-          throw new HybridSeException(s"Illegal window key type: $sparkType")
+          throw new HybridSeException(s"Illegal window key type: $colType")
       }
     }
   }
diff --git a/java/openmldb-batch/src/test/resources/log4j.properties b/java/openmldb-batch/src/test/resources/log4j.properties
index f332c949460..51d62679339 100755
--- a/java/openmldb-batch/src/test/resources/log4j.properties
+++ b/java/openmldb-batch/src/test/resources/log4j.properties
@@ -1,51 +1,9 @@
-### set log levels ###
-log4j.rootLogger=stdout,warn,error
-# console log
+log4j.rootLogger=WARN, stdout
+
+# Console log
 log4j.appender.stdout = org.apache.log4j.ConsoleAppender
 log4j.appender.stdout.Target = System.out
-log4j.appender.stdout.Threshold = INFO
 log4j.appender.stdout.layout = org.apache.log4j.PatternLayout
 log4j.appender.stdout.Encoding=UTF-8
-log4j.appender.stdout.layout.ConversionPattern = %d{yyyy-MM-dd HH:mm:ss} [ %c.%M(%F:%L) ] - [ %p ] %m%n
-
-#info log
-log4j.logger.info=info
-log4j.appender.info=org.apache.log4j.DailyRollingFileAppender
-log4j.appender.info.DatePattern='_'yyyy-MM-dd'.log'
-log4j.appender.info.File=logs/info.log
-log4j.appender.info.Append=true
-log4j.appender.info.Threshold=INFO
-log4j.appender.info.Encoding=UTF-8
-log4j.appender.info.layout=org.apache.log4j.PatternLayout
-log4j.appender.info.layout.ConversionPattern= %d{yyyy-MM-dd HH:mm:ss} [ %c.%M(%F:%L) ] - [ %p ] %m%n
-#debugs log
-log4j.logger.debug=debug
-log4j.appender.debug=org.apache.log4j.DailyRollingFileAppender
-log4j.appender.debug.DatePattern='_'yyyy-MM-dd'.log'
-log4j.appender.debug.File=logs/debug.log
-log4j.appender.debug.Append=true
-log4j.appender.debug.Threshold=DEBUG
-log4j.appender.debug.Encoding=UTF-8
-log4j.appender.debug.layout=org.apache.log4j.PatternLayout
-log4j.appender.debug.layout.ConversionPattern= %d{yyyy-MM-dd HH:mm:ss} [ %c.%M(%F:%L) ] - [ %p ] %m%n
-#warn log
-log4j.logger.warn=warn
-log4j.appender.warn=org.apache.log4j.DailyRollingFileAppender
-log4j.appender.warn.DatePattern='_'yyyy-MM-dd'.log'
-log4j.appender.warn.File=logs/warn.log
-log4j.appender.warn.Append=true
-log4j.appender.warn.Threshold=WARN
-log4j.appender.warn.Encoding=UTF-8
-log4j.appender.warn.layout=org.apache.log4j.PatternLayout
-log4j.appender.warn.layout.ConversionPattern= %d{yyyy-MM-dd HH:mm:ss} [ %c.%M(%F:%L) ] - [ %p ] %m%n
-#error
-log4j.logger.error=error
-log4j.appender.error = org.apache.log4j.DailyRollingFileAppender
-log4j.appender.error.DatePattern='_'yyyy-MM-dd'.log'
-log4j.appender.error.File = logs/error.log
-log4j.appender.error.Append = true
-log4j.appender.error.Threshold = ERROR
-log4j.appender.error.Encoding=UTF-8
-log4j.appender.error.layout = org.apache.log4j.PatternLayout
-log4j.appender.error.layout.ConversionPattern = %d{yyyy-MM-dd HH:mm:ss} [ %c.%M(%F:%L) ] - [ %p ] %m%n
+log4j.appender.stdout.layout.ConversionPattern = %c.%M(%F:%L) - %p: %m%n
diff --git a/java/openmldb-batch/src/test/scala/com/_4paradigm/openmldb/batch/end2end/DataUtil.scala b/java/openmldb-batch/src/test/scala/com/_4paradigm/openmldb/batch/end2end/DataUtil.scala
new file mode 100644
index 00000000000..7058fd58b2a
--- /dev/null
+++ b/java/openmldb-batch/src/test/scala/com/_4paradigm/openmldb/batch/end2end/DataUtil.scala
@@ -0,0 +1,51 @@
+/*
+ * Copyright 2021 4Paradigm
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com._4paradigm.openmldb.batch.end2end
+
+import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructField, StructType}
+import org.apache.spark.sql.{DataFrame, Row, SparkSession}
+
+object DataUtil {
+
+  def getStringDf(spark: SparkSession): DataFrame = {
+    val data = Seq(
+      Row(1, "abc", 100)
+    )
+    val schema = StructType(List(
+      StructField("int_col", IntegerType),
+      StructField("str_col", StringType),
+      StructField("int_col2", IntegerType)
+    ))
+    spark.createDataFrame(spark.sparkContext.makeRDD(data), schema)
+  }
+
+  def getTestDf(spark: SparkSession): DataFrame = {
+    val data = Seq(
+      Row(1, "tom", 100L, 1),
+      Row(2, "tom", 200L, 2),
+      Row(3, "tom", 300L, 3),
+      Row(4, "amy", 400L, 4),
+      Row(5, "amy", 500L, 5))
+    val schema = StructType(List(
+      StructField("id", IntegerType),
+      StructField("name", StringType),
+      StructField("trans_amount", LongType),
+      StructField("trans_time", IntegerType)))
+    spark.createDataFrame(spark.sparkContext.makeRDD(data), schema)
+  }
+
+}
diff --git a/java/openmldb-batch/src/test/scala/com/_4paradigm/openmldb/batch/end2end/unsafe/TestUnsafeProject.scala b/java/openmldb-batch/src/test/scala/com/_4paradigm/openmldb/batch/end2end/unsafe/TestUnsafeProject.scala
new file mode 100644
index 00000000000..df1c94fab05
--- /dev/null
+++ b/java/openmldb-batch/src/test/scala/com/_4paradigm/openmldb/batch/end2end/unsafe/TestUnsafeProject.scala
@@ -0,0 +1,52 @@
+/*
+ * Copyright 2021 4Paradigm
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com._4paradigm.openmldb.batch.end2end.unsafe
+
+import com._4paradigm.openmldb.batch.SparkTestSuite
+import com._4paradigm.openmldb.batch.api.OpenmldbSession
+import com._4paradigm.openmldb.batch.end2end.DataUtil
+import com._4paradigm.openmldb.batch.utils.SparkUtil
+
+class TestUnsafeProject extends SparkTestSuite {
+
+  override def customizedBefore(): Unit = {
+    val spark = getSparkSession
+    spark.conf.set("spark.openmldb.unsaferow.opt", true)
+  }
+
+  test("Test unsafe project") {
+    val spark = getSparkSession
+    val sess = new OpenmldbSession(spark)
+
+    val df = DataUtil.getStringDf(spark)
+    sess.registerTable("t1", df)
+    df.createOrReplaceTempView("t1")
+
+    val sqlText = "SELECT int_col, int_col2 + 1000 FROM t1"
+
+    val outputDf = sess.sql(sqlText)
+    val sparksqlOutputDf = sess.sparksql(sqlText)
+    assert(SparkUtil.approximateDfEqual(outputDf.getSparkDf(), sparksqlOutputDf, false))
+
+  }
+
+  override def customizedAfter(): Unit = {
+    val spark = getSparkSession
+    spark.conf.set("spark.openmldb.unsaferow.opt", false)
+  }
+
+}
diff --git a/java/openmldb-batch/src/test/scala/com/_4paradigm/openmldb/batch/end2end/unsafe/TestUnsafeWindow.scala b/java/openmldb-batch/src/test/scala/com/_4paradigm/openmldb/batch/end2end/unsafe/TestUnsafeWindow.scala
new file mode 100644
index 00000000000..6c99bb9ea6b
--- /dev/null
+++ b/java/openmldb-batch/src/test/scala/com/_4paradigm/openmldb/batch/end2end/unsafe/TestUnsafeWindow.scala
@@ -0,0 +1,57 @@
+/*
+ * Copyright 2021 4Paradigm
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com._4paradigm.openmldb.batch.end2end.unsafe
+
+import com._4paradigm.openmldb.batch.SparkTestSuite
+import com._4paradigm.openmldb.batch.api.OpenmldbSession
+import com._4paradigm.openmldb.batch.end2end.DataUtil
+import com._4paradigm.openmldb.batch.utils.SparkUtil
+
+class TestUnsafeWindow extends SparkTestSuite {
+
+  override def customizedBefore(): Unit = {
+    val spark = getSparkSession
+    spark.conf.set("spark.openmldb.unsaferow.opt", true)
+  }
+
+  test("Test unsafe window") {
+    val spark = getSparkSession
+    val sess = new OpenmldbSession(spark)
+
+    val df = DataUtil.getTestDf(spark)
+    sess.registerTable("t1", df)
+    df.createOrReplaceTempView("t1")
+
+    val sqlText = """
+       | SELECT id, sum(trans_amount) OVER w AS w_sum_amount FROM t1
+       | WINDOW w AS (
+       |    PARTITION BY id
+       |    ORDER BY trans_time
+       |    ROWS BETWEEN 10 PRECEDING AND CURRENT ROW);
+     """.stripMargin
+
+    val outputDf = sess.sql(sqlText)
+    val sparksqlOutputDf = sess.sparksql(sqlText)
+    assert(SparkUtil.approximateDfEqual(outputDf.getSparkDf(), sparksqlOutputDf, false))
+  }
+
+  override def customizedAfter(): Unit = {
+    val spark = getSparkSession
+    spark.conf.set("spark.openmldb.unsaferow.opt", false)
+  }
+
+}
diff --git a/java/openmldb-batch/src/test/scala/com/_4paradigm/openmldb/batch/utils/TestSparkRowUtil.scala b/java/openmldb-batch/src/test/scala/com/_4paradigm/openmldb/batch/utils/TestSparkRowUtil.scala
index c9ca2b94b79..1facce4a4ac 100644
--- a/java/openmldb-batch/src/test/scala/com/_4paradigm/openmldb/batch/utils/TestSparkRowUtil.scala
+++ b/java/openmldb-batch/src/test/scala/com/_4paradigm/openmldb/batch/utils/TestSparkRowUtil.scala
@@ -19,7 +19,8 @@ package com._4paradigm.openmldb.batch.utils
 import com._4paradigm.hybridse.sdk.HybridSeException
 import com._4paradigm.openmldb.batch.utils.SparkRowUtil.getLongFromIndex
 import org.apache.spark.sql.Row
-import org.apache.spark.sql.types.{ByteType, DateType, IntegerType, LongType, ShortType, TimestampType}
+import org.apache.spark.sql.types.{ByteType, DateType, IntegerType, LongType, ShortType, StringType, StructField,
+  StructType, TimestampType}
 import org.scalatest.FunSuite
 
 import java.sql.{Date, Timestamp}
@@ -57,4 +58,17 @@ class TestSparkRowUtil extends FunSuite {
     assert(getLongFromIndex(0, DateType, nullRow) == null)
   }
 
+  test("Test rowToString") {
+    val row = Row(1, "tom", 100, 1)
+    val schema = StructType(List(
+      StructField("id", IntegerType),
+      StructField("name", StringType),
+      StructField("trans_amount", IntegerType),
+      StructField("trans_time", IntegerType)))
+
+    val outputStr = SparkRowUtil.rowToString(schema, row)
+    val expectStr = "Spark row: IntegerType: 1, StringType: tom, IntegerType: 100, IntegerType: 1"
+    assert(outputStr.equals(expectStr))
+  }
+
 }
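
Note on the UnsafeRowOpt offset change in fe_row_codec.cc above: Spark's
UnsafeRow layout gives every column an 8-byte slot in the fixed-length
region (a string column keeps its bytes in the variable-length section and
stores an offset-and-length word in its slot), which is why the codec now
advances the offset for string columns as well. A minimal Scala sketch of
the rule this hunk implements; the helper name nextOffset is illustrative
and not part of the patch:

    // Advance the fixed-region offset past one column.
    // Under UnsafeRowOpt every column, string or not, consumes 8 bytes in
    // the fixed-length region. Without it, string data lives in the string
    // section and does not advance this offset.
    def nextOffset(offset: Int, isStringCol: Boolean, typeSize: Int, unsafeRowOpt: Boolean): Int =
      if (isStringCol) {
        if (unsafeRowOpt) offset + 8 else offset
      } else {
        // Non-string columns advance by their fixed type size (TYPE_SIZE_MAP in the codec)
        offset + typeSize
      }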
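Usage sketch for the new SparkRowUtil.rowToString helper; the expected
output below is taken from the "Test rowToString" case added in this patch:

    import org.apache.spark.sql.Row
    import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
    import com._4paradigm.openmldb.batch.utils.SparkRowUtil

    val schema = StructType(List(
      StructField("id", IntegerType),
      StructField("name", StringType)))

    // Null cells are rendered as "null" by getColumnStringValue.
    val row = Row(1, "tom")

    // Prints: Spark row: IntegerType: 1, StringType: tom
    println(SparkRowUtil.rowToString(schema, row))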