feat: fix unsaferow for window (#1298)
* Support config of debugShowNodeDf for openmldb-batch

* Add rowToString in spark row util and add unit tests

* Update offset for string and non-string columns in row codec

* Add unit tests for unsafe row opt

* Add data util for openmldb-batch unit tests
tobegit3hub authored Feb 24, 2022
1 parent 5cc52af commit 2e7fb76
Showing 11 changed files with 239 additions and 64 deletions.
1 change: 1 addition & 0 deletions Makefile
@@ -118,6 +118,7 @@ configure: thirdparty-fast

openmldb-clean:
rm -rf "$(OPENMLDB_BUILD_DIR)"
@cd java && ./mvnw clean

THIRD_PARTY_BUILD_DIR ?= $(MAKEFILE_DIR)/.deps
THIRD_PARTY_SRC_DIR ?= $(MAKEFILE_DIR)/thirdsrc
5 changes: 5 additions & 0 deletions hybridse/src/codec/fe_row_codec.cc
@@ -893,6 +893,11 @@ RowFormat::RowFormat(const hybridse::codec::Schema* schema)
next_str_pos_.insert(
std::make_pair(string_field_cnt, string_field_cnt));
string_field_cnt += 1;

if (FLAGS_enable_spark_unsaferow_format) {
// For UnsafeRowOpt, advance the offset for string columns as well as non-string columns
offset += 8;
}
} else {
auto TYPE_SIZE_MAP = codec::GetTypeSizeMap();
auto it = TYPE_SIZE_MAP.find(column.type());
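The fix above mirrors Spark's UnsafeRow layout, where the fixed-width region reserves one 8-byte slot for every column: non-string columns store the value inline, while string columns store a packed (offset, length) word pointing into the variable-length region, so the running offset must advance by 8 in both cases. A minimal Scala sketch of that layout arithmetic (not the OpenMLDB codec itself; the per-field width and null-bitmap rounding follow Spark's UnsafeRow format):

```scala
// Sketch of UnsafeRow fixed-region offsets: a null bitmap padded to 8-byte
// words, then one 8-byte slot per column -- strings included, which is the
// invariant the row-codec fix above restores.
object UnsafeRowLayoutSketch {
  // Null bitmap size in bytes, rounded up to whole 64-bit words.
  def bitmapBytes(numFields: Int): Int = ((numFields + 63) / 64) * 8

  // Byte offset of column i's fixed-width slot (the value itself, or a packed
  // offset/length word for variable-length types such as strings).
  def fieldOffset(numFields: Int, i: Int): Int = bitmapBytes(numFields) + 8 * i

  def main(args: Array[String]): Unit = {
    // For an (int, string, int) row every column advances the offset by 8,
    // so the slots start at bytes 8, 16 and 24 after the 8-byte bitmap.
    (0 until 3).foreach(i => println(s"column $i slot starts at byte ${fieldOffset(3, i)}"))
  }
}
```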
OpenmldbBatchConfig.scala
@@ -29,7 +29,7 @@ class OpenmldbBatchConfig extends Serializable {
@ConfigOption(name = "spark.sql.session.timeZone")
var timeZone = "Asia/Shanghai"

// test mode, used to verify related issues during testing
// test mode
@ConfigOption(name = "openmldb.test.tiny", doc = "控制读取表的数据条数,默认读全量数据")
var tinyData: Long = -1

@@ -42,6 +42,10 @@ class OpenmldbBatchConfig extends Serializable {
@ConfigOption(name = "openmldb.test.print", doc = "执行过程中允许打印数据")
var print: Boolean = false

// Debug options
@ConfigOption(name = "openmldb.debug.show_node_df", doc = "Use Spark DataFrame.show() for each physical nodes")
var debugShowNodeDf: Boolean = false

// Window skew optimization
@ConfigOption(name = "openmldb.window.skew.opt", doc = "Enable window skew optimization or not")
var enableWindowSkewOpt: Boolean = false
@@ -66,7 +70,6 @@ class OpenmldbBatchConfig extends Serializable {
@ConfigOption(name = "openmldb.window.skew.opt.config", doc = "The skew config for window skew optimization")
var windowSkewOptConfig: String = ""

// Slow run mode
@ConfigOption(name = "openmldb.slowRunCacheDir", doc =
"""
| Slow run mode cache directory path. If specified, run OpenMLDB plan with slow mode.
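For context, here is how the new debugShowNodeDf option might be switched on from a session. This is a hedged sketch: the tests later in this commit set OpenMLDB options through the Spark conf with a `spark.` prefix (e.g. `spark.openmldb.unsaferow.opt`), so the key below assumes the same convention applies to `openmldb.debug.show_node_df`; the table and query are illustrative only.

```scala
import org.apache.spark.sql.SparkSession
import com._4paradigm.openmldb.batch.api.OpenmldbSession

object DebugShowNodeDfSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").getOrCreate()
    // Assumed key: ConfigOption "openmldb.debug.show_node_df" with the
    // "spark." prefix used by the unsaferow tests in this commit.
    spark.conf.set("spark.openmldb.debug.show_node_df", true)

    val sess = new OpenmldbSession(spark)
    val df = spark.createDataFrame(Seq((1, "abc"), (2, "def"))).toDF("int_col", "str_col")
    sess.registerTable("t1", df)

    // With the flag on, SparkPlanner show()s each physical node's DataFrame
    // while planning, in addition to the final result below.
    sess.sql("SELECT int_col FROM t1").getSparkDf().show()
  }
}
```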
SparkPlanner.scala
@@ -273,10 +273,13 @@ class SparkPlanner(session: SparkSession, config: OpenmldbBatchConfig, sparkAppN
// Set the output to context cache
ctx.putPlanResult(root.GetNodeId(), outputSpatkInstance)

if (config.debugShowNodeDf) {
logger.warn(s"Debug and print DataFrame of nodeId: ${root.GetNodeId()}, nodeType: ${root.GetTypeName()}")
outputSpatkInstance.getDf().show()
}
outputSpatkInstance
}


/**
* Run plan slowly by storing and loading each intermediate result from external data path.
*/
WindowAggPlan.scala
@@ -279,9 +279,6 @@ object WindowAggPlan {
val skewGroups = config.groupIdxs :+ config.partIdIdx
computer.resetGroupKeyComparator(skewGroups)
}
if (sqlConfig.print) {
logger.info(s"windowAggIter mode: ${sqlConfig.enableWindowSkewOpt}")
}

val resIter = if (sqlConfig.enableWindowSkewOpt) {
limitInputIter.flatMap(zippedRow => {
@@ -341,9 +338,6 @@
val skewGroups = config.groupIdxs :+ config.partIdIdx
computer.resetGroupKeyComparator(skewGroups)
}
if (sqlConfig.print) {
logger.info(s"windowAggIter mode: ${sqlConfig.enableWindowSkewOpt}")
}

val resIter = if (sqlConfig.enableWindowSkewOpt) {
limitInputIter.flatMap(row => {
SparkRowUtil.scala
@@ -16,28 +16,65 @@

package com._4paradigm.openmldb.batch.utils

import com._4paradigm.hybridse.sdk.{HybridSeException, UnsupportedHybridSeException}
import com._4paradigm.hybridse.sdk.HybridSeException
import com._4paradigm.openmldb.proto.Type
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{
BooleanType, DataType, DateType, DoubleType, FloatType, IntegerType, LongType,
ShortType, StringType, TimestampType
}
import org.apache.spark.sql.types.{BooleanType, DataType, DateType, DoubleType, FloatType, IntegerType, LongType,
ShortType, StringType, StructType, TimestampType}

object SparkRowUtil {

def getLongFromIndex(keyIdx: Int, sparkType: DataType, row: Row): java.lang.Long = {
def rowToString(schema: StructType, row: Row): String = {
val rowStr = new StringBuilder("Spark row: ")
(0 until schema.size).foreach(i => {
if (i == 0) {
rowStr ++= s"${schema(i).dataType}: ${getColumnStringValue(schema, row, i)}"
} else {
rowStr ++= s", ${schema(i).dataType}: ${getColumnStringValue(schema, row, i)}"
}
})
rowStr.toString()
}

/**
* Get the string value of the specified column.
*
* @param schema the schema of the row
* @param row the row to read
* @param index the zero-based index of the column
* @return the column value as a string, or "null" if the cell is null
*/
def getColumnStringValue(schema: StructType, row: Row, index: Int): String = {
if (row.isNullAt(index)) {
"null"
} else {
val colType = schema(index).dataType
colType match {
case BooleanType => String.valueOf(row.getBoolean(index))
case ShortType => String.valueOf(row.getShort(index))
case DoubleType => String.valueOf(row.getDouble(index))
case IntegerType => String.valueOf(row.getInt(index))
case LongType => String.valueOf(row.getLong(index))
case TimestampType => String.valueOf(row.getTimestamp(index))
case DateType => String.valueOf(row.getDate(index))
case StringType => row.getString(index)
case _ =>
throw new HybridSeException(s"Unsupported data type: $colType")
}
}
}

def getLongFromIndex(keyIdx: Int, colType: DataType, row: Row): java.lang.Long = {
if (row.isNullAt(keyIdx)) {
null
} else {
sparkType match {
colType match {
case ShortType => row.getShort(keyIdx).toLong
case IntegerType => row.getInt(keyIdx).toLong
case LongType => row.getLong(keyIdx)
case TimestampType => row.getTimestamp(keyIdx).getTime
case DateType => row.getDate(keyIdx).getTime
case _ =>
throw new HybridSeException(s"Illegal window key type: $sparkType")
throw new HybridSeException(s"Illegal window key type: $colType")
}
}
}
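The new helpers are straightforward to exercise on a plain Row plus schema; a small sketch using only the signatures added above (the expected output follows from rowToString's format string):

```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import com._4paradigm.openmldb.batch.utils.SparkRowUtil

object RowToStringSketch {
  def main(args: Array[String]): Unit = {
    val schema = StructType(List(
      StructField("int_col", IntegerType),
      StructField("str_col", StringType)))

    // Prints "Spark row: IntegerType: 1, StringType: abc"
    println(SparkRowUtil.rowToString(schema, Row(1, "abc")))

    // getColumnStringValue is null-safe: a null cell renders as "null".
    println(SparkRowUtil.getColumnStringValue(schema, Row(null, "x"), 0))
  }
}
```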
50 changes: 4 additions & 46 deletions java/openmldb-batch/src/test/resources/log4j.properties
@@ -1,51 +1,9 @@
### set log levels ###
log4j.rootLogger=stdout,warn,error

# console log
log4j.rootLogger=WARN, stdout

# Console log
log4j.appender.stdout = org.apache.log4j.ConsoleAppender
log4j.appender.stdout.Target = System.out
log4j.appender.stdout.Threshold = INFO
log4j.appender.stdout.layout = org.apache.log4j.PatternLayout
log4j.appender.stdout.Encoding=UTF-8
log4j.appender.stdout.layout.ConversionPattern = %d{yyyy-MM-dd HH:mm:ss} [ %c.%M(%F:%L) ] - [ %p ] %m%n

#info log
log4j.logger.info=info
log4j.appender.info=org.apache.log4j.DailyRollingFileAppender
log4j.appender.info.DatePattern='_'yyyy-MM-dd'.log'
log4j.appender.info.File=logs/info.log
log4j.appender.info.Append=true
log4j.appender.info.Threshold=INFO
log4j.appender.info.Encoding=UTF-8
log4j.appender.info.layout=org.apache.log4j.PatternLayout
log4j.appender.info.layout.ConversionPattern= %d{yyyy-MM-dd HH:mm:ss} [ %c.%M(%F:%L) ] - [ %p ] %m%n
#debugs log
log4j.logger.debug=debug
log4j.appender.debug=org.apache.log4j.DailyRollingFileAppender
log4j.appender.debug.DatePattern='_'yyyy-MM-dd'.log'
log4j.appender.debug.File=logs/debug.log
log4j.appender.debug.Append=true
log4j.appender.debug.Threshold=DEBUG
log4j.appender.debug.Encoding=UTF-8
log4j.appender.debug.layout=org.apache.log4j.PatternLayout
log4j.appender.debug.layout.ConversionPattern= %d{yyyy-MM-dd HH:mm:ss} [ %c.%M(%F:%L) ] - [ %p ] %m%n
#warn log
log4j.logger.warn=warn
log4j.appender.warn=org.apache.log4j.DailyRollingFileAppender
log4j.appender.warn.DatePattern='_'yyyy-MM-dd'.log'
log4j.appender.warn.File=logs/warn.log
log4j.appender.warn.Append=true
log4j.appender.warn.Threshold=WARN
log4j.appender.warn.Encoding=UTF-8
log4j.appender.warn.layout=org.apache.log4j.PatternLayout
log4j.appender.warn.layout.ConversionPattern= %d{yyyy-MM-dd HH:mm:ss} [ %c.%M(%F:%L) ] - [ %p ] %m%n
#error
log4j.logger.error=error
log4j.appender.error = org.apache.log4j.DailyRollingFileAppender
log4j.appender.error.DatePattern='_'yyyy-MM-dd'.log'
log4j.appender.error.File = logs/error.log
log4j.appender.error.Append = true
log4j.appender.error.Threshold = ERROR
log4j.appender.error.Encoding=UTF-8
log4j.appender.error.layout = org.apache.log4j.PatternLayout
log4j.appender.error.layout.ConversionPattern = %d{yyyy-MM-dd HH:mm:ss} [ %c.%M(%F:%L) ] - [ %p ] %m%n
log4j.appender.stdout.layout.ConversionPattern = %c.%M(%F:%L) - %p: %m%n
DataUtil.scala
@@ -0,0 +1,51 @@
/*
* Copyright 2021 4Paradigm
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com._4paradigm.openmldb.batch.end2end

import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructField, StructType}
import org.apache.spark.sql.{DataFrame, Row, SparkSession}

object DataUtil {

def getStringDf(spark: SparkSession): DataFrame = {
val data = Seq(
Row(1, "abc", 100)
)
val schema = StructType(List(
StructField("int_col", IntegerType),
StructField("str_col", StringType),
StructField("int_col2", IntegerType)
))
spark.createDataFrame(spark.sparkContext.makeRDD(data), schema)
}

def getTestDf(spark: SparkSession): DataFrame = {
val data = Seq(
Row(1, "tom", 100L, 1),
Row(2, "tom", 200L, 2),
Row(3, "tom", 300L, 3),
Row(4, "amy", 400L, 4),
Row(5, "amy", 500L, 5))
val schema = StructType(List(
StructField("id", IntegerType),
StructField("name", StringType),
StructField("trans_amount", LongType),
StructField("trans_time", IntegerType)))
spark.createDataFrame(spark.sparkContext.makeRDD(data), schema)
}

}
TestUnsafeProject.scala
@@ -0,0 +1,52 @@
/*
* Copyright 2021 4Paradigm
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com._4paradigm.openmldb.batch.end2end.unsafe

import com._4paradigm.openmldb.batch.SparkTestSuite
import com._4paradigm.openmldb.batch.api.OpenmldbSession
import com._4paradigm.openmldb.batch.end2end.DataUtil
import com._4paradigm.openmldb.batch.utils.SparkUtil

class TestUnsafeProject extends SparkTestSuite {

override def customizedBefore(): Unit = {
val spark = getSparkSession
spark.conf.set("spark.openmldb.unsaferow.opt", true)
}

test("Test unsafe project") {
val spark = getSparkSession
val sess = new OpenmldbSession(spark)

val df = DataUtil.getStringDf(spark)
sess.registerTable("t1", df)
df.createOrReplaceTempView("t1")

val sqlText = "SELECT int_col, int_col2 + 1000 FROM t1"

val outputDf = sess.sql(sqlText)
val sparksqlOutputDf = sess.sparksql(sqlText)
assert(SparkUtil.approximateDfEqual(outputDf.getSparkDf(), sparksqlOutputDf, false))

}

override def customizedAfter(): Unit = {
val spark = getSparkSession
spark.conf.set("spark.openmldb.unsaferow.opt", false)
}

}
TestUnsafeWindow.scala
@@ -0,0 +1,57 @@
/*
* Copyright 2021 4Paradigm
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com._4paradigm.openmldb.batch.end2end.unsafe

import com._4paradigm.openmldb.batch.SparkTestSuite
import com._4paradigm.openmldb.batch.api.OpenmldbSession
import com._4paradigm.openmldb.batch.end2end.DataUtil
import com._4paradigm.openmldb.batch.utils.SparkUtil

class TestUnsafeWindow extends SparkTestSuite {

override def customizedBefore(): Unit = {
val spark = getSparkSession
spark.conf.set("spark.openmldb.unsaferow.opt", true)
}

test("Test unsafe window") {
val spark = getSparkSession
val sess = new OpenmldbSession(spark)

val df = DataUtil.getTestDf(spark)
sess.registerTable("t1", df)
df.createOrReplaceTempView("t1")

val sqlText ="""
| SELECT id, sum(trans_amount) OVER w AS w_sum_amount FROM t1
| WINDOW w AS (
| PARTITION BY id
| ORDER BY trans_time
| ROWS BETWEEN 10 PRECEDING AND CURRENT ROW);
""".stripMargin

val outputDf = sess.sql(sqlText)
val sparksqlOutputDf = sess.sparksql(sqlText)
assert(SparkUtil.approximateDfEqual(outputDf.getSparkDf(), sparksqlOutputDf, false))
}

override def customizedAfter(): Unit = {
val spark = getSparkSession
spark.conf.set("spark.openmldb.unsaferow.opt", false)
}

}
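Outside the test harness, enabling the optimization looks the same; a minimal sketch using the conf key set in the tests above (local master, data, and query are illustrative):

```scala
import org.apache.spark.sql.SparkSession
import com._4paradigm.openmldb.batch.api.OpenmldbSession

object UnsafeWindowSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").getOrCreate()
    // Same conf key the unsaferow tests set before running queries.
    spark.conf.set("spark.openmldb.unsaferow.opt", true)

    val sess = new OpenmldbSession(spark)
    val df = spark.createDataFrame(Seq(
      (1, "tom", 100L, 1), (2, "tom", 200L, 2))).toDF(
      "id", "name", "trans_amount", "trans_time")
    sess.registerTable("t1", df)

    val sql = """SELECT id, sum(trans_amount) OVER w AS w_sum_amount FROM t1
                | WINDOW w AS (PARTITION BY id ORDER BY trans_time
                | ROWS BETWEEN 10 PRECEDING AND CURRENT ROW)""".stripMargin
    sess.sql(sql).getSparkDf().show()
  }
}
```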