Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support unsaferow format for codec #1362

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 66 additions & 31 deletions hybridse/src/codec/fe_row_codec.cc
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,12 @@ RowBuilder::RowBuilder(const Schema& schema)
for (int idx = 0; idx < schema.size(); idx++) {
const ::hybridse::type::ColumnDef& column = schema.Get(idx);
if (column.type() == ::hybridse::type::kVarchar) {
offset_vec_.push_back(str_field_cnt_);
if (FLAGS_enable_spark_unsaferow_format) {
offset_vec_.push_back(str_field_start_offset_);
str_field_start_offset_ += 8;
} else {
offset_vec_.push_back(str_field_cnt_);
}
str_field_cnt_++;
} else {
auto TYPE_SIZE_MAP = GetTypeSizeMap();
Expand Down Expand Up @@ -106,7 +111,11 @@ bool RowBuilder::SetBuffer(int8_t* buf, uint32_t size) {
memset(buf_ + HEADER_LENGTH, 0, bitmap_size);
cnt_ = 0;
str_addr_length_ = GetAddrLength(size);
str_offset_ = str_field_start_offset_ + str_addr_length_ * str_field_cnt_;
if (FLAGS_enable_spark_unsaferow_format) {
str_offset_ = str_field_start_offset_;
} else {
str_offset_ = str_field_start_offset_ + str_addr_length_ * str_field_cnt_;
}
return true;
}

Expand Down Expand Up @@ -176,11 +185,17 @@ void FillNullStringOffset(int8_t* buf, uint32_t start, uint32_t addr_length,
bool RowBuilder::AppendNULL() {
int8_t* ptr = buf_ + HEADER_LENGTH + (cnt_ >> 3);
*(reinterpret_cast<uint8_t*>(ptr)) |= 1 << (cnt_ & 0x07);
const ::hybridse::type::ColumnDef& column = schema_.Get(cnt_);
if (column.type() == ::hybridse::type::kVarchar) {
FillNullStringOffset(buf_, str_field_start_offset_, str_addr_length_,
offset_vec_[cnt_], str_offset_);

if (FLAGS_enable_spark_unsaferow_format) {
// Do not fill null for UnsafeRowOpt
} else {
const ::hybridse::type::ColumnDef& column = schema_.Get(cnt_);
if (column.type() == ::hybridse::type::kVarchar) {
FillNullStringOffset(buf_, str_field_start_offset_, str_addr_length_,
offset_vec_[cnt_], str_offset_);
}
}

cnt_++;
return true;
}
Expand Down Expand Up @@ -257,22 +272,31 @@ bool RowBuilder::AppendDouble(double val) {
bool RowBuilder::AppendString(const char* val, uint32_t length) {
if (val == NULL || !Check(::hybridse::type::kVarchar)) return false;
if (str_offset_ + length > size_) return false;
int8_t* ptr =
buf_ + str_field_start_offset_ + str_addr_length_ * offset_vec_[cnt_];
if (str_addr_length_ == 1) {
*(reinterpret_cast<uint8_t*>(ptr)) = (uint8_t)str_offset_;
} else if (str_addr_length_ == 2) {
*(reinterpret_cast<uint16_t*>(ptr)) = (uint16_t)str_offset_;
} else if (str_addr_length_ == 3) {
*(reinterpret_cast<uint8_t*>(ptr)) = str_offset_ >> 16;
*(reinterpret_cast<uint8_t*>(ptr + 1)) = (str_offset_ & 0xFF00) >> 8;
*(reinterpret_cast<uint8_t*>(ptr + 2)) = str_offset_ & 0x00FF;

if (FLAGS_enable_spark_unsaferow_format) {
int8_t* ptr = buf_ + offset_vec_[cnt_];
*(reinterpret_cast<uint32_t*>(ptr)) = length;
*(reinterpret_cast<uint32_t*>(ptr + 4)) = str_offset_ - HEADER_LENGTH;
} else {
*(reinterpret_cast<uint32_t*>(ptr)) = str_offset_;
int8_t* ptr =
buf_ + str_field_start_offset_ + str_addr_length_ * offset_vec_[cnt_];
if (str_addr_length_ == 1) {
*(reinterpret_cast<uint8_t*>(ptr)) = (uint8_t)str_offset_;
} else if (str_addr_length_ == 2) {
*(reinterpret_cast<uint16_t*>(ptr)) = (uint16_t)str_offset_;
} else if (str_addr_length_ == 3) {
*(reinterpret_cast<uint8_t*>(ptr)) = str_offset_ >> 16;
*(reinterpret_cast<uint8_t*>(ptr + 1)) = (str_offset_ & 0xFF00) >> 8;
*(reinterpret_cast<uint8_t*>(ptr + 2)) = str_offset_ & 0x00FF;
} else {
*(reinterpret_cast<uint32_t*>(ptr)) = str_offset_;
}
}

if (length != 0) {
memcpy(reinterpret_cast<char*>(buf_ + str_offset_), val, length);
}

str_offset_ += length;
cnt_++;
return true;
Expand Down Expand Up @@ -330,7 +354,12 @@ bool RowView::Init() {
for (int idx = 0; idx < schema_.size(); idx++) {
const ::hybridse::type::ColumnDef& column = schema_.Get(idx);
if (column.type() == ::hybridse::type::kVarchar) {
offset_vec_.push_back(string_field_cnt_);
if (FLAGS_enable_spark_unsaferow_format) {
offset_vec_.push_back(offset);
offset += 8;
} else {
offset_vec_.push_back(string_field_cnt_);
}
string_field_cnt_++;
} else {
auto TYPE_SIZE_MAP = GetTypeSizeMap();
Expand Down Expand Up @@ -449,6 +478,7 @@ std::string RowView::GetStringUnsafe(uint32_t idx) {
}
const char* val;
uint32_t length;

v1::GetStrFieldUnsafe(row_, idx, field_offset, next_str_field_offset,
str_field_start_offset_, str_addr_length_, &val,
&length);
Expand Down Expand Up @@ -846,6 +876,7 @@ int32_t RowView::GetValue(const int8_t* row, uint32_t idx, const char** val,
if (offset_vec_.at(idx) < string_field_cnt_ - 1) {
next_str_field_offset = field_offset + 1;
}

return v1::GetStrFieldUnsafe(row, idx, field_offset, next_str_field_offset,
str_field_start_offset_, GetAddrLength(size),
val, length);
Expand Down Expand Up @@ -887,8 +918,14 @@ RowFormat::RowFormat(const hybridse::codec::Schema* schema)
for (int32_t i = 0; i < schema_->size(); i++) {
const ::hybridse::type::ColumnDef& column = schema_->Get(i);
if (column.type() == ::hybridse::type::kVarchar) {
infos_.push_back(
ColInfo(column.name(), column.type(), i, string_field_cnt));
if (FLAGS_enable_spark_unsaferow_format) {
infos_.push_back(
ColInfo(column.name(), column.type(), i, offset));
} else {
infos_.push_back(
ColInfo(column.name(), column.type(), i, string_field_cnt));
}

infos_dict_[column.name()] = i;
next_str_pos_.insert(
std::make_pair(string_field_cnt, string_field_cnt));
Expand Down Expand Up @@ -939,26 +976,24 @@ bool RowFormat::GetStringColumnInfo(size_t idx, StringColInfo* res) const {
auto ty = base_col_info.type;
uint32_t col_idx = base_col_info.idx;
uint32_t offset = base_col_info.offset;
uint32_t next_offset;
uint32_t next_offset = -1;
auto nit = next_str_pos_.find(offset);
if (nit != next_str_pos_.end()) {
next_offset = nit->second;
} else {
LOG(WARNING) << "fail to get string field next offset";
return false;
if (FLAGS_enable_spark_unsaferow_format) {
// Do not need to get next offset for UnsafeRowOpt
} else {
LOG(WARNING) << "fail to get string field next offset";
return false;
}
}
DLOG(INFO) << "get string with offset " << offset << " next offset "
<< next_offset << " str_field_start_offset "
<< str_field_start_offset_ << " for col " << base_col_info.name;

if (FLAGS_enable_spark_unsaferow_format) {
// Notice that we pass the nullbitmap size as str_field_start_offset
*res = StringColInfo(base_col_info.name, ty, col_idx, offset, next_offset,
BitMapSize(schema_->size()));
} else {
*res = StringColInfo(base_col_info.name, ty, col_idx, offset, next_offset,
str_field_start_offset_);
}
*res = StringColInfo(base_col_info.name, ty, col_idx, offset, next_offset,
str_field_start_offset_);
return true;
}

Expand Down
11 changes: 4 additions & 7 deletions hybridse/src/codec/type_codec.cc
Original file line number Diff line number Diff line change
Expand Up @@ -84,14 +84,11 @@ int32_t GetStrFieldUnsafe(const int8_t* row, uint32_t col_idx,

// Support Spark UnsafeRow format
if (FLAGS_enable_spark_unsaferow_format) {
// For UnsafeRow opt, str_start_offset is the nullbitmap size
const uint32_t bitmap_size = str_start_offset;
const int8_t* row_with_col_offset = row + HEADER_LENGTH + bitmap_size + col_idx * 8;
// Notice that for UnsafeRowOpt field_offset should be the actual offset of string column

// For Spark UnsafeRow, the first 32 bits is for length and the last
// 32 bits is for offset.
*size = *(reinterpret_cast<const uint32_t*>(row_with_col_offset));
uint32_t str_value_offset = *(reinterpret_cast<const uint32_t*>(row_with_col_offset + 4)) + HEADER_LENGTH;
// For Spark UnsafeRow, the first 32 bits is for length and the last 32 bits is for offset.
*size = *(reinterpret_cast<const uint32_t*>(row + field_offset));
uint32_t str_value_offset = *(reinterpret_cast<const uint32_t*>(row + field_offset + 4)) + HEADER_LENGTH;
*data = reinterpret_cast<const char*>(row + str_value_offset);

return 0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,18 @@ class OpenmldbBatchConfig extends Serializable {
@ConfigOption(name = "openmldb.unsaferow.opt", doc = "Enable UnsafeRow optimization or not")
var enableUnsafeRowOptimization = false

@ConfigOption(name = "openmldb.opt.unsaferow.project", doc = "Enable UnsafeRow optimization for project")
var enableUnsafeRowOptForProject = false

@ConfigOption(name = "openmldb.opt.unsaferow.window", doc = "Enable UnsafeRow optimization for window")
var enableUnsafeRowOptForWindow = false

@ConfigOption(name = "openmldb.opt.unsaferow.groupby", doc = "Enable UnsafeRow optimization for groupby")
var enableUnsafeRowOptForGroupby = false

@ConfigOption(name = "openmldb.opt.unsaferow.join", doc = "Enable UnsafeRow optimization for join")
var enableUnsafeRowOptForJoin = false

// Switch for disable OpenMLDB
@ConfigOption(name = "openmldb.disable", doc = "Disable OpenMLDB optimization or not")
var disableOpenmldb = false
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ object RowProjectPlan {

val openmldbJsdkLibraryPath = ctx.getConf.openmldbJsdkLibraryPath

val outputDf = if (ctx.getConf.enableUnsafeRowOptimization) { // Use UnsafeRow optimization
val outputDf = if (ctx.getConf.enableUnsafeRowOptForProject) { // Use UnsafeRow optimization

val outputInternalRowRdd = inputDf.queryExecution.toRdd.mapPartitions(partitionIter => {
val tag = projectConfig.moduleTag
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ object WindowAggPlan {
// Check if we should keep the index column
val isKeepIndexColumn = SparkInstance.keepIndexColumn(ctx, physicalNode.GetNodeId())
// Check if use UnsafeRow optimizaiton or not
val isUnsafeRowOptimization = ctx.getConf.enableUnsafeRowOptimization
val isUnsafeRowOptimization = ctx.getConf.enableUnsafeRowOptForWindow
// Check if we should keep the index column
val isWindowSkewOptimization = ctx.getConf.enableWindowSkewOpt

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
/*
* Copyright 2021 4Paradigm
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com._4paradigm.openmldb.batch.end2end.unsafe

import com._4paradigm.openmldb.batch.SparkTestSuite
import com._4paradigm.openmldb.batch.api.OpenmldbSession
import com._4paradigm.openmldb.batch.end2end.DataUtil
import com._4paradigm.openmldb.batch.utils.SparkUtil
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

class TestUnsafeGroupby extends SparkTestSuite {

override def customizedBefore(): Unit = {
val spark = getSparkSession
spark.conf.set("spark.openmldb.unsaferow.opt", true)
}

test("Test unsafe groupby") {
val spark = getSparkSession
val sess = new OpenmldbSession(spark)

val df = DataUtil.getTestDf(spark)
sess.registerTable("t1", df)
df.createOrReplaceTempView("t1")

val sqlText = "SELECT max(id) AS max_id, sum(trans_amount) AS sum_amount FROM t1 GROUP BY name"

val outputDf = sess.sql(sqlText)
val sparksqlOutputDf = sess.sparksql(sqlText)
assert(SparkUtil.approximateDfEqual(outputDf.getSparkDf(), sparksqlOutputDf, false))
}

override def customizedAfter(): Unit = {
val spark = getSparkSession
spark.conf.set("spark.openmldb.unsaferow.opt", false)
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
/*
* Copyright 2021 4Paradigm
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com._4paradigm.openmldb.batch.end2end.unsafe

import com._4paradigm.openmldb.batch.SparkTestSuite
import com._4paradigm.openmldb.batch.api.OpenmldbSession
import com._4paradigm.openmldb.batch.end2end.DataUtil
import com._4paradigm.openmldb.batch.utils.SparkUtil

class TestUnsafeJoin extends SparkTestSuite {

override def customizedBefore(): Unit = {
val spark = getSparkSession
spark.conf.set("spark.openmldb.unsaferow.opt", true)
}

test("Test unsafe join") {
val spark = getSparkSession
val sess = new OpenmldbSession(spark)

val df = DataUtil.getTestDf(spark)
sess.registerTable("t1", df)
sess.registerTable("t2", df)
df.createOrReplaceTempView("t1")
df.createOrReplaceTempView("t2")

val sqlText = "SELECT t1.id as t1_id, t2.id as t2_id, t1.name FROM t1 LEFT JOIN t2 ON t1.id = t2.id"

val outputDf = sess.sql(sqlText)
val sparksqlOutputDf = sess.sparksql(sqlText)
assert(SparkUtil.approximateDfEqual(outputDf.getSparkDf(), sparksqlOutputDf, false))
}

override def customizedAfter(): Unit = {
val spark = getSparkSession
spark.conf.set("spark.openmldb.unsaferow.opt", false)
}

}