diff --git a/hybridse/src/codec/fe_row_codec.cc b/hybridse/src/codec/fe_row_codec.cc index aa56f8bf97b..1cfcbdf3194 100644 --- a/hybridse/src/codec/fe_row_codec.cc +++ b/hybridse/src/codec/fe_row_codec.cc @@ -72,7 +72,12 @@ RowBuilder::RowBuilder(const Schema& schema) for (int idx = 0; idx < schema.size(); idx++) { const ::hybridse::type::ColumnDef& column = schema.Get(idx); if (column.type() == ::hybridse::type::kVarchar) { - offset_vec_.push_back(str_field_cnt_); + if (FLAGS_enable_spark_unsaferow_format) { + offset_vec_.push_back(str_field_start_offset_); + str_field_start_offset_ += 8; + } else { + offset_vec_.push_back(str_field_cnt_); + } str_field_cnt_++; } else { auto TYPE_SIZE_MAP = GetTypeSizeMap(); @@ -106,7 +111,11 @@ bool RowBuilder::SetBuffer(int8_t* buf, uint32_t size) { memset(buf_ + HEADER_LENGTH, 0, bitmap_size); cnt_ = 0; str_addr_length_ = GetAddrLength(size); - str_offset_ = str_field_start_offset_ + str_addr_length_ * str_field_cnt_; + if (FLAGS_enable_spark_unsaferow_format) { + str_offset_ = str_field_start_offset_; + } else { + str_offset_ = str_field_start_offset_ + str_addr_length_ * str_field_cnt_; + } return true; } @@ -176,11 +185,17 @@ void FillNullStringOffset(int8_t* buf, uint32_t start, uint32_t addr_length, bool RowBuilder::AppendNULL() { int8_t* ptr = buf_ + HEADER_LENGTH + (cnt_ >> 3); *(reinterpret_cast(ptr)) |= 1 << (cnt_ & 0x07); - const ::hybridse::type::ColumnDef& column = schema_.Get(cnt_); - if (column.type() == ::hybridse::type::kVarchar) { - FillNullStringOffset(buf_, str_field_start_offset_, str_addr_length_, - offset_vec_[cnt_], str_offset_); + + if (FLAGS_enable_spark_unsaferow_format) { + // Do not fill null for UnsafeRowOpt + } else { + const ::hybridse::type::ColumnDef& column = schema_.Get(cnt_); + if (column.type() == ::hybridse::type::kVarchar) { + FillNullStringOffset(buf_, str_field_start_offset_, str_addr_length_, + offset_vec_[cnt_], str_offset_); + } } + cnt_++; return true; } @@ -257,22 +272,31 @@ bool RowBuilder::AppendDouble(double val) { bool RowBuilder::AppendString(const char* val, uint32_t length) { if (val == NULL || !Check(::hybridse::type::kVarchar)) return false; if (str_offset_ + length > size_) return false; - int8_t* ptr = - buf_ + str_field_start_offset_ + str_addr_length_ * offset_vec_[cnt_]; - if (str_addr_length_ == 1) { - *(reinterpret_cast(ptr)) = (uint8_t)str_offset_; - } else if (str_addr_length_ == 2) { - *(reinterpret_cast(ptr)) = (uint16_t)str_offset_; - } else if (str_addr_length_ == 3) { - *(reinterpret_cast(ptr)) = str_offset_ >> 16; - *(reinterpret_cast(ptr + 1)) = (str_offset_ & 0xFF00) >> 8; - *(reinterpret_cast(ptr + 2)) = str_offset_ & 0x00FF; + + if (FLAGS_enable_spark_unsaferow_format) { + int8_t* ptr = buf_ + offset_vec_[cnt_]; + *(reinterpret_cast(ptr)) = length; + *(reinterpret_cast(ptr + 4)) = str_offset_ - HEADER_LENGTH; } else { - *(reinterpret_cast(ptr)) = str_offset_; + int8_t* ptr = + buf_ + str_field_start_offset_ + str_addr_length_ * offset_vec_[cnt_]; + if (str_addr_length_ == 1) { + *(reinterpret_cast(ptr)) = (uint8_t)str_offset_; + } else if (str_addr_length_ == 2) { + *(reinterpret_cast(ptr)) = (uint16_t)str_offset_; + } else if (str_addr_length_ == 3) { + *(reinterpret_cast(ptr)) = str_offset_ >> 16; + *(reinterpret_cast(ptr + 1)) = (str_offset_ & 0xFF00) >> 8; + *(reinterpret_cast(ptr + 2)) = str_offset_ & 0x00FF; + } else { + *(reinterpret_cast(ptr)) = str_offset_; + } } + if (length != 0) { memcpy(reinterpret_cast(buf_ + str_offset_), val, length); } + str_offset_ += length; cnt_++; return true; @@ -330,7 +354,12 @@ bool RowView::Init() { for (int idx = 0; idx < schema_.size(); idx++) { const ::hybridse::type::ColumnDef& column = schema_.Get(idx); if (column.type() == ::hybridse::type::kVarchar) { - offset_vec_.push_back(string_field_cnt_); + if (FLAGS_enable_spark_unsaferow_format) { + offset_vec_.push_back(offset); + offset += 8; + } else { + offset_vec_.push_back(string_field_cnt_); + } string_field_cnt_++; } else { auto TYPE_SIZE_MAP = GetTypeSizeMap(); @@ -449,6 +478,7 @@ std::string RowView::GetStringUnsafe(uint32_t idx) { } const char* val; uint32_t length; + v1::GetStrFieldUnsafe(row_, idx, field_offset, next_str_field_offset, str_field_start_offset_, str_addr_length_, &val, &length); @@ -846,6 +876,7 @@ int32_t RowView::GetValue(const int8_t* row, uint32_t idx, const char** val, if (offset_vec_.at(idx) < string_field_cnt_ - 1) { next_str_field_offset = field_offset + 1; } + return v1::GetStrFieldUnsafe(row, idx, field_offset, next_str_field_offset, str_field_start_offset_, GetAddrLength(size), val, length); @@ -887,8 +918,14 @@ RowFormat::RowFormat(const hybridse::codec::Schema* schema) for (int32_t i = 0; i < schema_->size(); i++) { const ::hybridse::type::ColumnDef& column = schema_->Get(i); if (column.type() == ::hybridse::type::kVarchar) { - infos_.push_back( - ColInfo(column.name(), column.type(), i, string_field_cnt)); + if (FLAGS_enable_spark_unsaferow_format) { + infos_.push_back( + ColInfo(column.name(), column.type(), i, offset)); + } else { + infos_.push_back( + ColInfo(column.name(), column.type(), i, string_field_cnt)); + } + infos_dict_[column.name()] = i; next_str_pos_.insert( std::make_pair(string_field_cnt, string_field_cnt)); @@ -939,26 +976,24 @@ bool RowFormat::GetStringColumnInfo(size_t idx, StringColInfo* res) const { auto ty = base_col_info.type; uint32_t col_idx = base_col_info.idx; uint32_t offset = base_col_info.offset; - uint32_t next_offset; + uint32_t next_offset = -1; auto nit = next_str_pos_.find(offset); if (nit != next_str_pos_.end()) { next_offset = nit->second; } else { - LOG(WARNING) << "fail to get string field next offset"; - return false; + if (FLAGS_enable_spark_unsaferow_format) { + // Do not need to get next offset for UnsafeRowOpt + } else { + LOG(WARNING) << "fail to get string field next offset"; + return false; + } } DLOG(INFO) << "get string with offset " << offset << " next offset " << next_offset << " str_field_start_offset " << str_field_start_offset_ << " for col " << base_col_info.name; - if (FLAGS_enable_spark_unsaferow_format) { - // Notice that we pass the nullbitmap size as str_field_start_offset - *res = StringColInfo(base_col_info.name, ty, col_idx, offset, next_offset, - BitMapSize(schema_->size())); - } else { - *res = StringColInfo(base_col_info.name, ty, col_idx, offset, next_offset, - str_field_start_offset_); - } + *res = StringColInfo(base_col_info.name, ty, col_idx, offset, next_offset, + str_field_start_offset_); return true; } diff --git a/hybridse/src/codec/type_codec.cc b/hybridse/src/codec/type_codec.cc index 576fcad5a67..3b0e0c176f0 100644 --- a/hybridse/src/codec/type_codec.cc +++ b/hybridse/src/codec/type_codec.cc @@ -84,14 +84,11 @@ int32_t GetStrFieldUnsafe(const int8_t* row, uint32_t col_idx, // Support Spark UnsafeRow format if (FLAGS_enable_spark_unsaferow_format) { - // For UnsafeRow opt, str_start_offset is the nullbitmap size - const uint32_t bitmap_size = str_start_offset; - const int8_t* row_with_col_offset = row + HEADER_LENGTH + bitmap_size + col_idx * 8; + // Notice that for UnsafeRowOpt field_offset should be the actual offset of string column - // For Spark UnsafeRow, the first 32 bits is for length and the last - // 32 bits is for offset. - *size = *(reinterpret_cast(row_with_col_offset)); - uint32_t str_value_offset = *(reinterpret_cast(row_with_col_offset + 4)) + HEADER_LENGTH; + // For Spark UnsafeRow, the first 32 bits is for length and the last 32 bits is for offset. + *size = *(reinterpret_cast(row + field_offset)); + uint32_t str_value_offset = *(reinterpret_cast(row + field_offset + 4)) + HEADER_LENGTH; *data = reinterpret_cast(row + str_value_offset); return 0; diff --git a/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/OpenmldbBatchConfig.scala b/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/OpenmldbBatchConfig.scala index c34ab330103..8c96d7a9aae 100755 --- a/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/OpenmldbBatchConfig.scala +++ b/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/OpenmldbBatchConfig.scala @@ -122,6 +122,18 @@ class OpenmldbBatchConfig extends Serializable { @ConfigOption(name = "openmldb.unsaferow.opt", doc = "Enable UnsafeRow optimization or not") var enableUnsafeRowOptimization = false + @ConfigOption(name = "openmldb.opt.unsaferow.project", doc = "Enable UnsafeRow optimization for project") + var enableUnsafeRowOptForProject = false + + @ConfigOption(name = "openmldb.opt.unsaferow.window", doc = "Enable UnsafeRow optimization for window") + var enableUnsafeRowOptForWindow = false + + @ConfigOption(name = "openmldb.opt.unsaferow.groupby", doc = "Enable UnsafeRow optimization for groupby") + var enableUnsafeRowOptForGroupby = false + + @ConfigOption(name = "openmldb.opt.unsaferow.join", doc = "Enable UnsafeRow optimization for join") + var enableUnsafeRowOptForJoin = false + // Switch for disable OpenMLDB @ConfigOption(name = "openmldb.disable", doc = "Disable OpenMLDB optimization or not") var disableOpenmldb = false diff --git a/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/nodes/RowProjectPlan.scala b/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/nodes/RowProjectPlan.scala index 33588da74ab..66f2b61a4da 100644 --- a/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/nodes/RowProjectPlan.scala +++ b/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/nodes/RowProjectPlan.scala @@ -71,7 +71,7 @@ object RowProjectPlan { val openmldbJsdkLibraryPath = ctx.getConf.openmldbJsdkLibraryPath - val outputDf = if (ctx.getConf.enableUnsafeRowOptimization) { // Use UnsafeRow optimization + val outputDf = if (ctx.getConf.enableUnsafeRowOptForProject) { // Use UnsafeRow optimization val outputInternalRowRdd = inputDf.queryExecution.toRdd.mapPartitions(partitionIter => { val tag = projectConfig.moduleTag diff --git a/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/nodes/WindowAggPlan.scala b/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/nodes/WindowAggPlan.scala index 8e2d476f6c7..b3d6cdcb771 100755 --- a/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/nodes/WindowAggPlan.scala +++ b/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/nodes/WindowAggPlan.scala @@ -54,7 +54,7 @@ object WindowAggPlan { // Check if we should keep the index column val isKeepIndexColumn = SparkInstance.keepIndexColumn(ctx, physicalNode.GetNodeId()) // Check if use UnsafeRow optimizaiton or not - val isUnsafeRowOptimization = ctx.getConf.enableUnsafeRowOptimization + val isUnsafeRowOptimization = ctx.getConf.enableUnsafeRowOptForWindow // Check if we should keep the index column val isWindowSkewOptimization = ctx.getConf.enableWindowSkewOpt diff --git a/java/openmldb-batch/src/test/scala/com/_4paradigm/openmldb/batch/end2end/unsafe/TestUnsafeGroupby.scala b/java/openmldb-batch/src/test/scala/com/_4paradigm/openmldb/batch/end2end/unsafe/TestUnsafeGroupby.scala new file mode 100644 index 00000000000..5c385243434 --- /dev/null +++ b/java/openmldb-batch/src/test/scala/com/_4paradigm/openmldb/batch/end2end/unsafe/TestUnsafeGroupby.scala @@ -0,0 +1,53 @@ +/* + * Copyright 2021 4Paradigm + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com._4paradigm.openmldb.batch.end2end.unsafe + +import com._4paradigm.openmldb.batch.SparkTestSuite +import com._4paradigm.openmldb.batch.api.OpenmldbSession +import com._4paradigm.openmldb.batch.end2end.DataUtil +import com._4paradigm.openmldb.batch.utils.SparkUtil +import org.apache.spark.sql.Row +import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} + +class TestUnsafeGroupby extends SparkTestSuite { + + override def customizedBefore(): Unit = { + val spark = getSparkSession + spark.conf.set("spark.openmldb.unsaferow.opt", true) + } + + test("Test unsafe groupby") { + val spark = getSparkSession + val sess = new OpenmldbSession(spark) + + val df = DataUtil.getTestDf(spark) + sess.registerTable("t1", df) + df.createOrReplaceTempView("t1") + + val sqlText = "SELECT max(id) AS max_id, sum(trans_amount) AS sum_amount FROM t1 GROUP BY name" + + val outputDf = sess.sql(sqlText) + val sparksqlOutputDf = sess.sparksql(sqlText) + assert(SparkUtil.approximateDfEqual(outputDf.getSparkDf(), sparksqlOutputDf, false)) + } + + override def customizedAfter(): Unit = { + val spark = getSparkSession + spark.conf.set("spark.openmldb.unsaferow.opt", false) + } + +} diff --git a/java/openmldb-batch/src/test/scala/com/_4paradigm/openmldb/batch/end2end/unsafe/TestUnsafeJoin.scala b/java/openmldb-batch/src/test/scala/com/_4paradigm/openmldb/batch/end2end/unsafe/TestUnsafeJoin.scala new file mode 100644 index 00000000000..94c1bf6219f --- /dev/null +++ b/java/openmldb-batch/src/test/scala/com/_4paradigm/openmldb/batch/end2end/unsafe/TestUnsafeJoin.scala @@ -0,0 +1,53 @@ +/* + * Copyright 2021 4Paradigm + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com._4paradigm.openmldb.batch.end2end.unsafe + +import com._4paradigm.openmldb.batch.SparkTestSuite +import com._4paradigm.openmldb.batch.api.OpenmldbSession +import com._4paradigm.openmldb.batch.end2end.DataUtil +import com._4paradigm.openmldb.batch.utils.SparkUtil + +class TestUnsafeJoin extends SparkTestSuite { + + override def customizedBefore(): Unit = { + val spark = getSparkSession + spark.conf.set("spark.openmldb.unsaferow.opt", true) + } + + test("Test unsafe join") { + val spark = getSparkSession + val sess = new OpenmldbSession(spark) + + val df = DataUtil.getTestDf(spark) + sess.registerTable("t1", df) + sess.registerTable("t2", df) + df.createOrReplaceTempView("t1") + df.createOrReplaceTempView("t2") + + val sqlText = "SELECT t1.id as t1_id, t2.id as t2_id, t1.name FROM t1 LEFT JOIN t2 ON t1.id = t2.id" + + val outputDf = sess.sql(sqlText) + val sparksqlOutputDf = sess.sparksql(sqlText) + assert(SparkUtil.approximateDfEqual(outputDf.getSparkDf(), sparksqlOutputDf, false)) + } + + override def customizedAfter(): Unit = { + val spark = getSparkSession + spark.conf.set("spark.openmldb.unsaferow.opt", false) + } + +}