[NSE-207] fix left/right outer join in SMJ (#416)
* fix the loss of stream data

* use SMJ in WSCG

* fix left/right outer join

* disable inplace with project

* refine
rui-mo authored Jul 23, 2021
1 parent b49b0a2 commit 4da9887
Showing 8 changed files with 97 additions and 64 deletions.
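The heart of the fix is in ColumnarSortMergeJoin below: a left/right outer join must keep emitting stream-side rows padded with nulls after the build side runs dry, so an empty build batch is fed through instead of stopping early. A minimal Scala sketch of that idea (the helper name and signature are hypothetical; only the `allocateColumns` call and the `ColumnarBatch` construction mirror the diff):

```scala
import com.intel.oap.vectorized.ArrowWritableColumnVector
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.vectorized.{ColumnVector, ColumnarBatch}

// Sketch: decide what to feed the native prober next.
def nextBuildBatch(
    buildIter: Iterator[ColumnarBatch],
    streamHasNext: Boolean,
    isLeftOrRightOuter: Boolean,
    buildSchema: StructType): Option[ColumnarBatch] = {
  if (buildIter.hasNext) {
    Some(buildIter.next()) // normal case: another real build batch
  } else if (isLeftOrRightOuter && streamHasNext) {
    // Outer join: feed a zero-row batch so the remaining stream rows
    // still come out, padded with nulls on the build side.
    val vectors = ArrowWritableColumnVector.allocateColumns(0, buildSchema).toArray
    Some(new ColumnarBatch(vectors.map(_.asInstanceOf[ColumnVector]), 0))
  } else {
    None // inner-style joins may stop once the build side is exhausted
  }
}
```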
@@ -21,7 +21,6 @@ import java.util.concurrent.TimeUnit._
 import com.intel.oap.ColumnarPluginConfig
 import com.intel.oap.vectorized.ArrowWritableColumnVector
-
 import org.apache.spark.TaskContext
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.expressions._
@@ -37,7 +36,8 @@ import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide
 import scala.collection.JavaConverters._
 import org.apache.spark.SparkConf
 import org.apache.spark.internal.Logging
-import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector}
+import org.apache.spark.sql.vectorized.{ColumnVector, ColumnarBatch}
+
 import scala.collection.mutable.ListBuffer
 import org.apache.arrow.vector.ipc.message.ArrowFieldNode
 import org.apache.arrow.vector.ipc.message.ArrowRecordBatch
@@ -46,20 +46,19 @@ import org.apache.arrow.vector.types.pojo.Field
 import org.apache.arrow.vector.types.pojo.Schema
 import org.apache.arrow.gandiva.expression._
 import org.apache.arrow.gandiva.evaluator._
 
 import org.apache.arrow.memory.ArrowBuf
 
-import com.google.common.collect.Lists;
-
+import com.google.common.collect.Lists
 import org.apache.spark.sql.types.{DataType, DecimalType, StructType}
 import com.intel.oap.vectorized.ExpressionEvaluator
 import com.intel.oap.vectorized.BatchIterator
 import org.apache.spark.sql.util.ArrowUtils
 
 /**
  * Performs a sort merge join of two child relations.
  */
 class ColumnarSortMergeJoin(
     prober: ExpressionEvaluator,
+    build_input_arrow_schema: Schema,
     stream_input_arrow_schema: Schema,
     output_arrow_schema: Schema,
     leftKeys: Seq[Expression],
@@ -110,9 +109,43 @@ class ColumnarSortMergeJoin(
         }
         build_cb = realbuildIter.next()
         val beforeBuild = System.nanoTime()
+        val projectedBuildKeyCols: List[ArrowWritableColumnVector] = if (buildProjector != null) {
+          val builderOrdinalList = buildProjector.getOrdinalList()
+          val builderAttributes = buildProjector.output()
+          val builderProjectCols = builderOrdinalList.map(i => {
+            build_cb.column(i).asInstanceOf[ArrowWritableColumnVector]
+          })
+          buildProjector.evaluate(build_cb.numRows, builderProjectCols.map(_.getValueVector()))
+        } else {
+          List[ArrowWritableColumnVector]()
+        }
+        val buildCols = (0 until build_cb.numCols).toList.map(i =>
+          build_cb.column(i).asInstanceOf[ArrowWritableColumnVector]) ::: projectedBuildKeyCols
+        val build_rb =
+          ConverterUtils.createArrowRecordBatch(build_cb.numRows, buildCols.map(_.getValueVector))
+
+        buildCols.indices.toList.foreach(i =>
+          buildCols(i).retain())
+        inputBatchHolder += build_cb
+        prober.evaluate(build_rb)
+        prepareTime += NANOSECONDS.toMillis(System.nanoTime() - beforeBuild)
+        ConverterUtils.releaseArrowRecordBatch(build_rb)
+        projectedBuildKeyCols.foreach(v => v.close())
+      }
+      if (build_cb != null) {
+        build_cb = null
+      } else {
+        if ((joinType == LeftOuter || joinType == RightOuter) && realstreamIter.hasNext) {
+          // If stream side still has next batch, an empty batch is assigned to build side.
+          val resultColumnVectors = ArrowWritableColumnVector
+            .allocateColumns(0, ArrowUtils.fromArrowSchema(build_input_arrow_schema))
+            .toArray
+          build_cb =
+            new ColumnarBatch(resultColumnVectors.map(_.asInstanceOf[ColumnVector]), 0)
+          val beforeBuild = System.nanoTime()
         val projectedBuildKeyCols: List[ArrowWritableColumnVector] = if (buildProjector != null) {
-          val builderOrdinalList = buildProjector.getOrdinalList
-          val builderAttributes = buildProjector.output
+          val builderOrdinalList = buildProjector.getOrdinalList()
+          val builderAttributes = buildProjector.output()
           val builderProjectCols = builderOrdinalList.map(i => {
             build_cb.column(i).asInstanceOf[ArrowWritableColumnVector]
           })
@@ -124,31 +157,26 @@ class ColumnarSortMergeJoin(
           build_cb.column(i).asInstanceOf[ArrowWritableColumnVector]) ::: projectedBuildKeyCols
         val build_rb =
           ConverterUtils.createArrowRecordBatch(build_cb.numRows, buildCols.map(_.getValueVector))
-
-        (0 until buildCols.size).toList.foreach(i =>
-          buildCols(i).retain())
-        inputBatchHolder += build_cb
-        prober.evaluate(build_rb)
-        prepareTime += NANOSECONDS.toMillis(System.nanoTime() - beforeBuild)
-        ConverterUtils.releaseArrowRecordBatch(build_rb)
-        projectedBuildKeyCols.foreach(v => v.close)
-      }
-      if (build_cb != null) {
-        build_cb = null
-      } else {
-        val res = new Iterator[ColumnarBatch] {
-          override def hasNext: Boolean = {
-            false
-          }
-
-          override def next(): ColumnarBatch = {
-            val resultColumnVectors = ArrowWritableColumnVector
-              .allocateColumns(0, resultSchema)
-              .toArray
-            new ColumnarBatch(resultColumnVectors.map(_.asInstanceOf[ColumnVector]), 0)
-          }
-        }
-        return res
-      }
+          buildCols.indices.toList.foreach(i => buildCols(i).retain())
+          inputBatchHolder += build_cb
+          prober.evaluate(build_rb)
+          prepareTime += NANOSECONDS.toMillis(System.nanoTime() - beforeBuild)
+          ConverterUtils.releaseArrowRecordBatch(build_rb)
+          projectedBuildKeyCols.foreach(v => v.close)
+        } else {
+          val res = new Iterator[ColumnarBatch] {
+            override def hasNext: Boolean = {
+              false
+            }
+            override def next(): ColumnarBatch = {
+              val resultColumnVectors = ArrowWritableColumnVector
+                .allocateColumns(0, resultSchema)
+                .toArray
+              new ColumnarBatch(resultColumnVectors.map(_.asInstanceOf[ColumnVector]), 0)
+            }
+          }
+          return res
+        }
+      }
       val beforeBuild = System.nanoTime()
       probe_iterator = prober.finishByIterator()
@@ -633,6 +661,7 @@ object ColumnarSortMergeJoin extends Logging {
 
     new ColumnarSortMergeJoin(
       prober,
+      build_input_arrow_schema,
       stream_input_arrow_schema,
       output_arrow_schema,
       leftKeys,
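The feed loop above follows a fixed Arrow buffer lifecycle: retain every column before the native prober reads it, release the record batch afterwards, and close the temporary projected key vectors. A condensed sketch of that pattern (the ConverterUtils import path is an assumption; error handling is trimmed relative to the diff):

```scala
import com.intel.oap.expression.ConverterUtils
import com.intel.oap.vectorized.{ArrowWritableColumnVector, ExpressionEvaluator}
import org.apache.spark.sql.vectorized.ColumnarBatch

// Sketch: hand one build batch (plus projected key columns) to the
// native prober without leaking Arrow buffers.
def feedBuildBatch(
    prober: ExpressionEvaluator,
    cb: ColumnarBatch,
    projectedKeyCols: List[ArrowWritableColumnVector]): Unit = {
  val cols = (0 until cb.numCols).toList
    .map(cb.column(_).asInstanceOf[ArrowWritableColumnVector]) ::: projectedKeyCols
  val rb = ConverterUtils.createArrowRecordBatch(cb.numRows, cols.map(_.getValueVector))
  cols.foreach(_.retain()) // keep buffers alive while native code reads them
  try prober.evaluate(rb)
  finally {
    ConverterUtils.releaseArrowRecordBatch(rb)
    projectedKeyCols.foreach(_.close()) // projected keys are temporaries
  }
}
```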
@@ -148,9 +148,7 @@ case class ColumnarCollapseCodegenStages(
       if (count >= 1) true
       else plan.children.map(existsJoins(_, count + 1)).exists(_ == true)
     case p: ColumnarSortMergeJoinExec =>
-      if (p.joinType.isInstanceOf[ExistenceJoin]) true
-      if (count >= 1) true
-      else plan.children.map(existsJoins(_, count + 1)).exists(_ == true)
+      true
     case p: ColumnarHashAggregateExec =>
       if (count >= 1) true
       else plan.children.map(existsJoins(_, count + 1)).exists(_ == true)
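This collapse-stage change makes a columnar SMJ unconditionally count as a join when deciding whether a whole-stage-codegen stage exists; previously it only counted when nested under another join, and the old `if (p.joinType.isInstanceOf[ExistenceJoin]) true` was a no-op whose value was discarded. A paraphrase of the rule after the change (simplified signature; import paths assumed):

```scala
import com.intel.oap.execution.ColumnarSortMergeJoinExec
import org.apache.spark.sql.execution.SparkPlan

// Paraphrased: does this subtree contain a codegen-eligible join?
def existsJoins(plan: SparkPlan, count: Int = 0): Boolean = plan match {
  case _: ColumnarSortMergeJoinExec => true // SMJ now always qualifies
  case _ =>
    if (count >= 1) true
    else plan.children.exists(existsJoins(_, count + 1))
}
```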
@@ -157,7 +157,7 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession with SQLHelper
   /** For Debug Use only
   * List of test cases to test, in lower cases. */
   protected def testList: Set[String] = Set(
-
+    "udf/postgreSQL/udf-aggregates_part1.sql"
   )
 
   /** List of test cases to ignore, in lower cases. */
@@ -167,7 +167,7 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession with SQLHelper
     /** segfault, compilation error and exception */
 
     "postgreSQL/window_part3.sql", // WindowSortKernel::Impl::GetCompFunction_
-    "subquery/in-subquery/in-joins.sql", // NullPointerException
+    "subquery/in-subquery/in-joins.sql", // NullPointerException: LocalTableScanExec.stringArgs
     "udf/postgreSQL/udf-aggregates_part1.sql", // IllegalStateException: Value at index is null
 
     /** Cannot reproduce */
@@ -192,7 +192,6 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession with SQLHelper
     // "explain-aqe.sql", // plan check
     // "explain.sql", // plan check
     // "describe.sql", // AnalysisException
-    // "subquery/scalar-subquery/scalar-subquery-select.sql", // SMJ LeftAnti
     // "ansi/decimalArithmeticOperations.sql",
     // "postgreSQL/union.sql", // aggregate-groupby
     // "postgreSQL/int4.sql", // exception expected
@@ -141,7 +141,7 @@ class ExistenceJoinSuite extends SparkPlanTest with SharedSparkSession {
     }
   }
 
-  ignore(s"$testName using SortMergeJoin") {
+  test(s"$testName using SortMergeJoin") {
     extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _, _) =>
       withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
         checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
@@ -408,6 +408,7 @@ class ConditionedMergeJoinKernel::Impl {
   auto streamed_relation = "sort_relation_" + std::to_string(relation_id_[1]) + "_";
   auto left_index_name = "left_index_" + relation_id;
   auto right_index_name = "right_index_" + relation_id;
+  auto outer_join_num_matches_name = "outer_join_num_matches_" + relation_id;
 
   ///// Get Matched row /////
   codes_ss << "int " << range_name << " = 0;" << std::endl;
@@ -449,16 +450,14 @@ class ConditionedMergeJoinKernel::Impl {
     codes_ss << right_for_loop_codes.str();
     codes_ss << "auto is_smj_" << relation_id << " = false;" << std::endl;
   }
+  codes_ss << "int " << outer_join_num_matches_name << " = 0;" << std::endl;
   codes_ss << "for (int " << range_id << " = 0; " << range_id << " < " << range_name
            << "; " << range_id << "++) {" << std::endl;
+  codes_ss << "if(!" << fill_null_name << "){" << std::endl;
+  codes_ss << "if(" << function_name << "_res == 0"
+           << ") {" << std::endl;
   codes_ss << left_index_name << " = " << build_relation << "->GetItemIndexWithShift("
            << range_id << ");" << std::endl;
-  codes_ss << "}" << std::endl;
-  if (!cache_right) {
-    codes_ss << "auto is_smj_" << relation_id << " = false;" << std::endl;
-    codes_ss << right_for_loop_codes.str();
-  }
-  codes_ss << fill_null_name << " = false;" << std::endl;
   if (cond_check) {
     auto condition_name = "ConditionCheck_" + std::to_string(relation_id_[0]);
     if (use_relation_for_stream) {
@@ -468,8 +467,18 @@ class ConditionedMergeJoinKernel::Impl {
       codes_ss << "if (!" << condition_name << "(" << left_index_name << ")) {"
                << std::endl;
     }
-    codes_ss << fill_null_name << " = true;" << std::endl;
-    codes_ss << "}" << std::endl;
+    codes_ss << "if ((" << range_id << " + 1) == " << range_name << " && "
+             << outer_join_num_matches_name << " == 0) {" << std::endl;
+    codes_ss << fill_null_name << " = true; } else {" << std::endl;
+    codes_ss << "continue;" << std::endl;
+    codes_ss << "}}" << std::endl;
   }
+  codes_ss << "} else {" << std::endl;
+  codes_ss << fill_null_name << " = true;}" << std::endl;
+  codes_ss << outer_join_num_matches_name << " += 1;" << std::endl;
+  if (!cache_right) {
+    codes_ss << "auto is_smj_" << relation_id << " = false;" << std::endl;
+    codes_ss << right_for_loop_codes.str();
+  }
   finish_codes_ss << "} // end of Outer Join" << std::endl;
   if (use_relation_for_stream) {
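Stripped of the string-plumbing, the generated loop's intent is: walk each stream row's candidate range, count matches, and null-fill only when the last candidate has been reached with zero matches. A hand-written Scala restatement of that control flow (illustrative only; the real code is generated C++):

```scala
// Illustrative restatement of the generated outer-join probe loop:
// emit matching build indices, or one None (null-fill) if nothing matched.
def probeOneStreamRow(
    candidateRange: Seq[Int],
    conditionPasses: Int => Boolean): Seq[Option[Int]] = {
  var numMatches = 0
  val out = Seq.newBuilder[Option[Int]]
  for ((buildIdx, i) <- candidateRange.zipWithIndex) {
    if (conditionPasses(buildIdx)) {
      numMatches += 1
      out += Some(buildIdx) // matched: output with build-side columns
    } else if (i + 1 == candidateRange.length && numMatches == 0) {
      out += None // last candidate and no match seen: null-fill once
    }
  }
  if (candidateRange.isEmpty) out += None // no candidates at all: null-fill
  out.result()
}
```

The counter is what the old code was missing: it null-filled as soon as one candidate failed the condition, instead of only after the whole range produced no match.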
@@ -1986,7 +1986,7 @@ SortArraysToIndicesKernel::SortArraysToIndicesKernel(
 
   if (key_field_list.size() == 1 && result_schema->num_fields() == 1 &&
       key_field_list[0]->type()->id() != arrow::Type::STRING &&
-      key_field_list[0]->type()->id() != arrow::Type::BOOL) {
+      key_field_list[0]->type()->id() != arrow::Type::BOOL && !pre_processed_key_) {
     // Will use SortInplace when sorting for one non-string and non-boolean col
 #ifdef DEBUG
     std::cout << "UseSortInplace" << std::endl;
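This guard implements the "disable inplace with project" bullet from the commit message: the single-key in-place sort path is additionally skipped when the sort key was pre-processed by a projection. Expressed as a plain predicate (a sketch; `pre_processed_key_` is taken to mean the key column was produced by key pre-projection):

```scala
// Sketch of the in-place sort eligibility test after this change.
def canSortInplace(
    numSortKeys: Int,
    numResultFields: Int,
    keyIsString: Boolean,
    keyIsBool: Boolean,
    keyIsPreProcessed: Boolean): Boolean =
  numSortKeys == 1 && numResultFields == 1 &&
    !keyIsString && !keyIsBool &&
    !keyIsPreProcessed // projected keys take the indices-based sort path
```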
2 changes: 0 additions & 2 deletions native-sql-engine/cpp/src/third_party/datetime/date.h
@@ -924,7 +924,6 @@ std::basic_ostream<CharT, Traits>& operator<<(std::basic_ostream<CharT, Traits>&
 
 #if !defined(_MSC_VER) || (_MSC_VER >= 1900)
 inline namespace literals {
-
 CONSTCD11 date::day operator"" _d(unsigned long long d) NOEXCEPT;
 CONSTCD11 date::year operator"" _y(unsigned long long y) NOEXCEPT;
 
@@ -1603,7 +1602,6 @@ inline std::basic_ostream<CharT, Traits>& operator<<(
 
 #if !defined(_MSC_VER) || (_MSC_VER >= 1900)
 inline namespace literals {
-
 CONSTCD11
 inline date::day operator"" _d(unsigned long long d) NOEXCEPT {
   return date::day{static_cast<unsigned>(d)};
@@ -1,29 +1,29 @@ (whitespace-only changes to the license header)
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
  *
- * contributor license agreements. See the NOTICE file distributed with
+ * contributor license agreements. See the NOTICE file distributed with
  * this
- * work for additional information regarding copyright ownership.
+ * work for additional information regarding copyright ownership.
  * The ASF
- * licenses this file to You under the Apache License, Version 2.0
+ * licenses this file to You under the Apache License, Version 2.0
  * (the
- * "License"); you may not use this file except in compliance with
+ * "License"); you may not use this file except in compliance with
  * the
- * License. You may obtain a copy of the License at
- *
+ * License. You may obtain a copy of the License at
+ *
  *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
  * Unless required by
- * applicable law or agreed to in writing, software
+ * applicable law or agreed to in writing, software
  * distributed under the
- * License is distributed on an "AS IS" BASIS,
+ * License is distributed on an "AS IS" BASIS,
  * WITHOUT WARRANTIES OR
- * CONDITIONS OF ANY KIND, either express or implied.
+ * CONDITIONS OF ANY KIND, either express or implied.
  * See the License for the
- * specific language governing permissions and
+ * specific language governing permissions and
  * limitations under the
- * License.
+ * License.
  */
 #ifndef __NATIVE_MEMORY_H
 #define __NATIVE_MEMORY_H
