feat: optimize join condition expr #1502

Merged
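
Summary: replaces the JIT-compiled UDF evaluation of extra LAST JOIN conditions with join conditions built directly as Spark Column expressions from the HybridSE expression tree. A new openmldb.opt.join.spark_expr option (default true) selects the Spark-expression path; the encoder/decoder + JIT UDF path is kept as a fallback.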
hybridse/include/node/sql_node.h (4 additions, 0 deletions)

@@ -1497,6 +1497,10 @@ class BinaryExpr : public ExprNode {

     Status InferAttr(ExprAnalysisContext *ctx) override;
 
+    static BinaryExpr *CastFrom(ExprNode *node) {
+        return dynamic_cast<BinaryExpr *>(node);
+    }
+
  private:
     FnOperator op_;
 };
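The new CastFrom helper follows the checked-downcast pattern used by other ExprNode subclasses and is reached from Scala through the generated Java bindings. A minimal sketch of how a caller might guard it, assuming the binding maps the C++ dynamic_cast's null result to a Java null:

import com._4paradigm.hybridse.node.{BinaryExpr, ExprNode}

// Hedged sketch: CastFrom wraps a C++ dynamic_cast, so it is expected to
// return null when the node is not actually a BinaryExpr.
def asBinaryExpr(node: ExprNode): Option[BinaryExpr] =
  Option(BinaryExpr.CastFrom(node))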
@@ -131,8 +131,9 @@ class OpenmldbBatchConfig extends Serializable {
@ConfigOption(name = "openmldb.opt.unsaferow.groupby", doc = "Enable UnsafeRow optimization for groupby")
var enableUnsafeRowOptForGroupby = false

@ConfigOption(name = "openmldb.opt.unsaferow.join", doc = "Enable UnsafeRow optimization for join")
var enableUnsafeRowOptForJoin = false
// Join optimization
@ConfigOption(name = "openmldb.opt.join.spark_expr", doc = "Enable join with original Spark expression")
var enableJoinWithSparkExpr = true

// Switch for disable OpenMLDB
@ConfigOption(name = "openmldb.disable", doc = "Disable OpenMLDB optimization or not")
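How a batch job might toggle the new flag. The "spark." prefix on the option name is an assumption based on how other openmldb.* options are passed through the Spark configuration; verify against your deployment:

import org.apache.spark.sql.SparkSession

// Hypothetical session setup; "spark.openmldb.opt.join.spark_expr" assumes
// the usual "spark." prefix on ConfigOption names.
val spark = SparkSession.builder()
  .appName("join-expr-demo")
  // set to "false" to fall back to the encoder/decoder + JIT UDF path
  .config("spark.openmldb.opt.join.spark_expr", "true")
  .getOrCreate()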
@@ -20,11 +20,11 @@
 import com._4paradigm.hybridse.sdk.UnsupportedHybridSeException
 import com._4paradigm.hybridse.node.{ConstNode, ExprType, DataType => HybridseDataType}
 import com._4paradigm.hybridse.vm.PhysicalConstProjectNode
 import com._4paradigm.openmldb.batch.{PlanContext, SparkInstance}
-import com._4paradigm.openmldb.batch.utils.HybridseUtil
+import com._4paradigm.openmldb.batch.utils.{ExpressionUtil, HybridseUtil}
 import org.apache.spark.sql.Column
 import org.apache.spark.sql.functions.{lit, to_date, to_timestamp, typedLit, when}
-import org.apache.spark.sql.types.{BooleanType, DateType, DoubleType, FloatType,
-  IntegerType, LongType, ShortType, StringType, TimestampType}
+import org.apache.spark.sql.types.{BooleanType, DateType, DoubleType, FloatType, IntegerType, LongType, ShortType,
+  StringType, TimestampType}
 
 import scala.collection.JavaConverters.asScalaBufferConverter

@@ -51,7 +51,7 @@ object ConstProjectPlan {
       val outputColName = outputColNameList(i)
 
       // Create simple literal Spark column
-      val column = getConstCol(constNode)
+      val column = ExpressionUtil.constExprToSparkColumn(constNode)
 
       // Match column type for output type
       castSparkOutputCol(column, constNode.GetDataType(), outputColTypeList(i))
@@ -68,33 +68,7 @@ object ConstProjectPlan {
     SparkInstance.createConsideringIndex(ctx, node.GetNodeId(), result)
   }
 
-  // Generate Spark column from const value
-  def getConstCol(constNode: ConstNode): Column = {
-    constNode.GetDataType() match {
-      case HybridseDataType.kNull => lit(null)
-
-      case HybridseDataType.kInt16 =>
-        typedLit[Short](constNode.GetAsInt16())
-
-      case HybridseDataType.kInt32 =>
-        typedLit[Int](constNode.GetAsInt32())
-
-      case HybridseDataType.kInt64 =>
-        typedLit[Long](constNode.GetAsInt64())
-
-      case HybridseDataType.kFloat =>
-        typedLit[Float](constNode.GetAsFloat())
-
-      case HybridseDataType.kDouble =>
-        typedLit[Double](constNode.GetAsDouble())
-
-      case HybridseDataType.kVarchar =>
-        typedLit[String](constNode.GetAsString())
-
-      case _ => throw new UnsupportedHybridSeException(
-        s"Const value for HybridSE type ${constNode.GetDataType()} not supported")
-    }
-  }
-
   def castSparkOutputCol(inputCol: Column,
                          fromType: HybridseDataType,
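For reference, the kind of literal columns the now-centralized const conversion produces. A minimal, self-contained Spark sketch, independent of OpenMLDB, showing the lit/typedLit builders used by ExpressionUtil.constExprToSparkColumn below:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{lit, typedLit}

object ConstColumnDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("const-demo").getOrCreate()
    val df = spark.range(1).select(
      typedLit[Short](7: Short).as("i16"),  // ShortType column, as for kInt16
      typedLit[Long](7L).as("i64"),         // LongType column, as for kInt64
      lit(null).as("null_col")              // NullType column, as for kNull
    )
    df.printSchema()
    spark.stop()
  }
}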
@@ -19,10 +19,10 @@
 package com._4paradigm.openmldb.batch.nodes
 import com._4paradigm.hybridse.`type`.TypeOuterClass.ColumnDef
 import com._4paradigm.hybridse.codec
 import com._4paradigm.hybridse.codec.RowView
-import com._4paradigm.hybridse.sdk.{HybridSeException, JitManager, SerializableByteBuffer}
-import com._4paradigm.hybridse.node.{ExprListNode, JoinType}
+import com._4paradigm.hybridse.sdk.{HybridSeException, JitManager, SerializableByteBuffer, UnsupportedHybridSeException}
+import com._4paradigm.hybridse.node.{BinaryExpr, ExprListNode, ExprType, FnOperator, JoinType}
 import com._4paradigm.hybridse.vm.{CoreAPI, HybridSeJitWrapper, PhysicalJoinNode}
-import com._4paradigm.openmldb.batch.utils.{HybridseUtil, SparkColumnUtil, SparkRowUtil, SparkUtil}
+import com._4paradigm.openmldb.batch.utils.{ExpressionUtil, HybridseUtil, SparkColumnUtil, SparkRowUtil, SparkUtil}
 import com._4paradigm.openmldb.batch.{PlanContext, SparkInstance, SparkRowCodec}
 import com._4paradigm.openmldb.sdk.impl.SqlClusterExecutor
 import org.apache.spark.sql.types.StructType
@@ -67,13 +67,16 @@ object JoinPlan {

     val indexName = "__JOIN_INDEX__" + System.currentTimeMillis()
 
+    var hasIndexColumn = false
+
     val leftDf: DataFrame = {
       if (joinType == JoinType.kJoinTypeLeft) {
         left.getDf()
       } else {
         if (supportNativeLastJoin && ctx.getConf.enableNativeLastJoin) {
           left.getDf()
         } else {
+          hasIndexColumn = true
           // Add index column for original last join, not used in native last join
           SparkUtil.addIndexColumn(spark, left.getDf(), indexName, ctx.getConf.addIndexColumnMethod)
         }
@@ -96,41 +96,46 @@ object JoinPlan {
       }
     }
 
-    val indexColIdx = if (joinType == JoinType.kJoinTypeLast) {
-      leftDf.schema.size - 1
-    } else if (supportNativeLastJoin && ctx.getConf.enableNativeLastJoin) {
+    val indexColIdx = if (hasIndexColumn) {
+      leftDf.schema.size - 1
     } else {
-      leftDf.schema.size
+      -1
     }
 
     val filter = node.join().condition()
     // extra conditions
     if (filter.condition() != null) {
-      val regName = "SPARKFE_JOIN_CONDITION_" + filter.fn_info().fn_name()
-      val conditionUDF = new JoinConditionUDF(
-        functionName = filter.fn_info().fn_name(),
-        inputSchemaSlices = inputSchemaSlices,
-        outputSchema = filter.fn_info().fn_schema(),
-        moduleTag = ctx.getTag,
-        moduleBroadcast = ctx.getSerializableModuleBuffer,
-        hybridseJsdkLibraryPath = ctx.getConf.openmldbJsdkLibraryPath
-      )
-      spark.udf.register(regName, conditionUDF)
-
-      // Handle the duplicated column names to get Spark Column by index
-      val allColumns = new mutable.ArrayBuffer[Column]()
-      for (i <- leftDf.schema.indices) {
-        if (i != indexColIdx) {
-          allColumns += SparkColumnUtil.getColumnFromIndex(leftDf, i)
+      if (ctx.getConf.enableJoinWithSparkExpr) {
+        joinConditions += ExpressionUtil.recusiveGetSparkColumnFromExpr(filter.condition(), node, leftDf, rightDf,
+          hasIndexColumn)
+        logger.info("Generate spark join conditions: " + joinConditions)
+      } else { // Disable join with native expression, use encoder/decoder and jit function
+        val regName = "SPARKFE_JOIN_CONDITION_" + filter.fn_info().fn_name()
+        val conditionUDF = new JoinConditionUDF(
+          functionName = filter.fn_info().fn_name(),
+          inputSchemaSlices = inputSchemaSlices,
+          outputSchema = filter.fn_info().fn_schema(),
+          moduleTag = ctx.getTag,
+          moduleBroadcast = ctx.getSerializableModuleBuffer,
+          hybridseJsdkLibraryPath = ctx.getConf.openmldbJsdkLibraryPath
+        )
+        spark.udf.register(regName, conditionUDF)
+
+        // Handle the duplicated column names to get Spark Column by index
+        val allColumns = new mutable.ArrayBuffer[Column]()
+        for (i <- leftDf.schema.indices) {
+          if (i != indexColIdx) {
+            allColumns += SparkColumnUtil.getColumnFromIndex(leftDf, i)
+          }
         }
-      }
-      for (i <- rightDf.schema.indices) {
-        allColumns += SparkColumnUtil.getColumnFromIndex(rightDf, i)
-      }
-
-      val allColWrap = functions.struct(allColumns: _*)
-      joinConditions += functions.callUDF(regName, allColWrap)
+        for (i <- rightDf.schema.indices) {
+          allColumns += SparkColumnUtil.getColumnFromIndex(rightDf, i)
+        }
+
+        val allColWrap = functions.struct(allColumns: _*)
+        joinConditions += functions.callUDF(regName, allColWrap)
+      }
     }
 
     if (joinConditions.isEmpty) {
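For contrast, the mechanics of the retained fallback path in isolation: wrap the input columns into a single struct and evaluate a registered boolean UDF over it. A trivial predicate stands in for the JIT-compiled condition function here, and all names are illustrative only:

import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.api.java.UDF1
import org.apache.spark.sql.functions
import org.apache.spark.sql.types.DataTypes

object UdfConditionDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("udf-cond-demo").getOrCreate()
    import spark.implicits._

    val df = Seq((1, 10), (20, 5)).toDF("a", "b")

    // Stand-in for the JIT-compiled join condition UDF
    spark.udf.register("DEMO_JOIN_CONDITION", new UDF1[Row, Boolean] {
      override def call(row: Row): Boolean = row.getInt(0) < row.getInt(1)
    }, DataTypes.BooleanType)

    val allColWrap = functions.struct(df.col("a"), df.col("b"))
    df.filter(functions.callUDF("DEMO_JOIN_CONDITION", allColWrap)).show()
    spark.stop()
  }
}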
@@ -19,7 +19,7 @@
 package com._4paradigm.openmldb.batch.nodes
 import com._4paradigm.hybridse.sdk.UnsupportedHybridSeException
 import com._4paradigm.hybridse.node.{CastExprNode, ConstNode, ExprNode, ExprType, DataType => HybridseDataType}
 import com._4paradigm.hybridse.vm.{CoreAPI, PhysicalSimpleProjectNode}
-import com._4paradigm.openmldb.batch.utils.{HybridseUtil, SparkColumnUtil}
+import com._4paradigm.openmldb.batch.utils.{ExpressionUtil, HybridseUtil, SparkColumnUtil}
 import com._4paradigm.openmldb.batch.{PlanContext, SparkInstance}
 import org.apache.spark.sql.{Column, DataFrame}
 import org.slf4j.LoggerFactory
@@ -100,7 +100,7 @@ object SimpleProjectPlan {

       case ExprType.kExprPrimary =>
         val const = ConstNode.CastFrom(expr)
-        ConstProjectPlan.getConstCol(const) -> const.GetDataType
+        ExpressionUtil.constExprToSparkColumn(const) -> const.GetDataType
 
       case ExprType.kExprCast =>
         val cast = CastExprNode.CastFrom(expr)
@@ -0,0 +1,189 @@ new file: object com._4paradigm.openmldb.batch.utils.ExpressionUtil
/*
 * Copyright 2021 4Paradigm
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com._4paradigm.openmldb.batch.utils

import com._4paradigm.hybridse.node.{BinaryExpr, ConstNode, ExprNode, ExprType, FnOperator,
  DataType => HybridseDataType}
import com._4paradigm.hybridse.sdk.UnsupportedHybridSeException
import com._4paradigm.hybridse.vm.{CoreAPI, PhysicalOpNode}
import org.apache.spark.sql.functions.{lit, typedLit}
import org.apache.spark.sql.{Column, DataFrame}

object ExpressionUtil {

  /**
   * Convert a const expression to a Spark Column object.
   *
   * @param constNode the HybridSE const node to convert
   * @return the equivalent Spark Column literal
   */
  def constExprToSparkColumn(constNode: ConstNode): Column = {
    constNode.GetDataType() match {
      case HybridseDataType.kNull => lit(null)

      case HybridseDataType.kInt16 =>
        typedLit[Short](constNode.GetAsInt16())

      case HybridseDataType.kInt32 =>
        typedLit[Int](constNode.GetAsInt32())

      case HybridseDataType.kInt64 =>
        typedLit[Long](constNode.GetAsInt64())

      case HybridseDataType.kFloat =>
        typedLit[Float](constNode.GetAsFloat())

      case HybridseDataType.kDouble =>
        typedLit[Double](constNode.GetAsDouble())

      case HybridseDataType.kVarchar =>
        typedLit[String](constNode.GetAsString())

      case _ => throw new UnsupportedHybridSeException(
        s"Const value for HybridSE type ${constNode.GetDataType()} not supported")
    }
  }

  /**
   * Convert an expr object to a Spark Column object.
   * Notice that this only works for some non-computing expressions.
   *
   * @param expr the HybridSE expression node to convert
   * @param leftDf the DataFrame of the left join input
   * @param rightDf the DataFrame of the right join input
   * @param physicalNode the physical node used to resolve column indexes
   * @param hasIndexColumn whether leftDf carries an appended join-index column
   * @return the equivalent Spark Column
   */
  def exprToSparkColumn(expr: ExprNode,
                        leftDf: DataFrame,
                        rightDf: DataFrame,
                        physicalNode: PhysicalOpNode,
                        hasIndexColumn: Boolean): Column = {
    expr.GetExprType() match {
      case ExprType.kExprColumnRef | ExprType.kExprColumnId =>
        val inputNode = physicalNode
        val colIndex = CoreAPI.ResolveColumnIndex(inputNode, expr)
        if (colIndex < 0 || colIndex >= inputNode.GetOutputSchemaSize()) {
          inputNode.Print()
          inputNode.PrintSchema()
          throw new IndexOutOfBoundsException(
            s"${expr.GetExprString()} resolved index out of bound: $colIndex")
        }
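
        // Worked example of the mapping below, with hypothetical sizes:
        // if hasIndexColumn is true and leftDf.schema.size == 4 (three payload
        // columns plus the appended join-index column), resolved indices 0..2
        // read from leftDf, and resolved index 3 maps to rightDf column
        // 3 - (4 - 1) = 0.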

        if (hasIndexColumn) {
          if (colIndex < leftDf.schema.size - 1) {
            // Get from left df
            SparkColumnUtil.getColumnFromIndex(leftDf, colIndex)
          } else {
            // Get from right df
            val rightColIndex = colIndex - (leftDf.schema.size - 1)
            SparkColumnUtil.getColumnFromIndex(rightDf, rightColIndex)
          }
        } else {
          if (colIndex < leftDf.schema.size) {
            // Get from left df
            SparkColumnUtil.getColumnFromIndex(leftDf, colIndex)
          } else {
            // Get from right df
            val rightColIndex = colIndex - leftDf.schema.size
            SparkColumnUtil.getColumnFromIndex(rightDf, rightColIndex)
          }
        }

      case ExprType.kExprPrimary =>
        val const = ConstNode.CastFrom(expr)
        ExpressionUtil.constExprToSparkColumn(const)

      case _ => throw new UnsupportedHybridSeException(
        s"Do not support converting expression to Spark Column for expression type ${expr.GetExprType}")
    }
  }

  /**
   * Convert a binary expression to two Spark Column objects, one per child.
   *
   * @param binaryExpr the HybridSE binary expression
   * @param physicalNode the physical node used to resolve column indexes
   * @param leftDf the DataFrame of the left join input
   * @param rightDf the DataFrame of the right join input
   * @param hasIndexColumn whether leftDf carries an appended join-index column
   * @return the Spark Columns for the left and right children
   */
  def binaryExprToSparkColumns(binaryExpr: BinaryExpr, physicalNode: PhysicalOpNode, leftDf: DataFrame,
                               rightDf: DataFrame, hasIndexColumn: Boolean): (Column, Column) = {
    val leftExpr = binaryExpr.GetChild(0)
    val rightExpr = binaryExpr.GetChild(1)
    val leftSparkColumn = ExpressionUtil.exprToSparkColumn(leftExpr, leftDf, rightDf, physicalNode, hasIndexColumn)
    val rightSparkColumn = ExpressionUtil.exprToSparkColumn(rightExpr, leftDf, rightDf, physicalNode, hasIndexColumn)
    leftSparkColumn -> rightSparkColumn
  }


  def recusiveGetSparkColumnFromExpr(expr: ExprNode, node: PhysicalOpNode, leftDf: DataFrame,
                                     rightDf: DataFrame, hasIndexColumn: Boolean): Column = {
    expr.GetExprType() match {
      case ExprType.kExprBinary =>
        val binaryExpr = BinaryExpr.CastFrom(expr)
        val op = binaryExpr.GetOp()
        op match {
          case FnOperator.kFnOpAnd =>
            // TODO(tobe): Only support for binary sub expressions
            val leftExpr = BinaryExpr.CastFrom(binaryExpr.GetChild(0))
            val rightExpr = BinaryExpr.CastFrom(binaryExpr.GetChild(1))
            val leftColumn = recusiveGetSparkColumnFromExpr(leftExpr, node, leftDf, rightDf, hasIndexColumn)
            val rightColumn = recusiveGetSparkColumnFromExpr(rightExpr, node, leftDf, rightDf, hasIndexColumn)
            leftColumn.and(rightColumn)
          case FnOperator.kFnOpOr =>
            val leftExpr = BinaryExpr.CastFrom(binaryExpr.GetChild(0))
            val rightExpr = BinaryExpr.CastFrom(binaryExpr.GetChild(1))
            val leftColumn = recusiveGetSparkColumnFromExpr(leftExpr, node, leftDf, rightDf, hasIndexColumn)
            val rightColumn = recusiveGetSparkColumnFromExpr(rightExpr, node, leftDf, rightDf, hasIndexColumn)
            leftColumn.or(rightColumn)
          case FnOperator.kFnOpNot =>
            // Recurse into the child expression; recursing on expr itself would never terminate
            !recusiveGetSparkColumnFromExpr(binaryExpr.GetChild(0), node, leftDf, rightDf, hasIndexColumn)
          case FnOperator.kFnOpEq => // TODO(tobe): Support null-safe equal in the future
            // Notice that it may be handled by physical plan's left_key() and right_key()
            val (left, right) = ExpressionUtil.binaryExprToSparkColumns(binaryExpr, node, leftDf, rightDf,
              hasIndexColumn)
            left.equalTo(right)
          case FnOperator.kFnOpNeq =>
            val (left, right) = ExpressionUtil.binaryExprToSparkColumns(binaryExpr, node, leftDf, rightDf,
              hasIndexColumn)
            left.notEqual(right)
          case FnOperator.kFnOpLt =>
            val (left, right) = ExpressionUtil.binaryExprToSparkColumns(binaryExpr, node, leftDf, rightDf,
              hasIndexColumn)
            left < right
          case FnOperator.kFnOpLe =>
            val (left, right) = ExpressionUtil.binaryExprToSparkColumns(binaryExpr, node, leftDf, rightDf,
              hasIndexColumn)
            left <= right
          case FnOperator.kFnOpGt =>
            val (left, right) = ExpressionUtil.binaryExprToSparkColumns(binaryExpr, node, leftDf, rightDf,
              hasIndexColumn)
            left > right
          case FnOperator.kFnOpGe =>
            val (left, right) = ExpressionUtil.binaryExprToSparkColumns(binaryExpr, node, leftDf, rightDf,
              hasIndexColumn)
            left >= right
          case _ => throw new UnsupportedHybridSeException(
            s"Do not support converting binary operator $op to Spark Column")
        }
      case _ => throw new UnsupportedHybridSeException(
        s"Do not support converting expression type ${expr.GetExprType} to Spark Column")
    }
  }

}
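
Taken together, a hedged sketch of what the Spark-expression path produces. The Column tree below is hand-built and stands in for the output of recusiveGetSparkColumnFromExpr for a condition like "t1.a = t2.a AND t1.b > t2.b"; table contents are made up for illustration:

import org.apache.spark.sql.SparkSession

object JoinExprDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("join-expr-demo").getOrCreate()
    import spark.implicits._

    val leftDf = Seq((1, 10), (2, 20)).toDF("a", "b")
    val rightDf = Seq((1, 5), (2, 30)).toDF("a", "b")

    // Hand-built equivalent of the generated join condition Column tree
    val cond = leftDf("a").equalTo(rightDf("a")).and(leftDf("b") > rightDf("b"))

    leftDf.join(rightDf, cond, "left_outer").show()
    spark.stop()
  }
}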