Skip to content
This repository has been archived by the owner on Sep 18, 2023. It is now read-only.

[NSE-1019] fix codegen for all expressions #1039

Merged
merged 10 commits into from
Sep 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,13 @@ class ColumnarAdd(left: Expression, right: Expression, original: Expression)
extends Add(left: Expression, right: Expression)
with ColumnarExpression
with Logging {
val gName = "add"

/**
 * Tells the planner whether this addition can be emitted through columnar codegen.
 *
 * Requires `gName` to be present in `codegenFuncList` and both operands to
 * report codegen support themselves (checked left first, short-circuiting).
 */
override def supportColumnarCodegen(args: java.lang.Object): Boolean = {
  def operandSupported(operand: Expression): Boolean =
    operand.asInstanceOf[ColumnarExpression].supportColumnarCodegen(args)

  codegenFuncList.contains(gName) && operandSupported(left) && operandSupported(right)
}

// If casting between DecimalType, unnecessary cast is skipped to avoid data loss,
// because actually res type of "cast" is the res type in "add/subtract",
Expand Down Expand Up @@ -103,6 +110,14 @@ class ColumnarSubtract(left: Expression, right: Expression, original: Expression
with ColumnarExpression
with Logging {

val gName = "subtract"

/**
 * Codegen is usable for this subtraction only when `gName` is listed in
 * `codegenFuncList` and each operand, left then right, also supports it.
 */
override def supportColumnarCodegen(args: java.lang.Object): Boolean = {
  if (!codegenFuncList.contains(gName)) {
    false
  } else {
    val lhs = left.asInstanceOf[ColumnarExpression]
    lhs.supportColumnarCodegen(args) &&
      right.asInstanceOf[ColumnarExpression].supportColumnarCodegen(args)
  }
}

val left_val: Any = left match {
case c: ColumnarCast =>
if (c.child.dataType.isInstanceOf[DecimalType] &&
Expand Down Expand Up @@ -163,6 +178,14 @@ class ColumnarMultiply(left: Expression, right: Expression, original: Expression
with ColumnarExpression
with Logging {

val gName = "multiply"

/**
 * Reports columnar codegen support for this multiplication: `gName` must be in
 * `codegenFuncList` and both children must report support (left is asked first).
 */
override def supportColumnarCodegen(args: java.lang.Object): Boolean = {
  def operandSupported(operand: Expression): Boolean =
    operand.asInstanceOf[ColumnarExpression].supportColumnarCodegen(args)

  codegenFuncList.contains(gName) && operandSupported(left) && operandSupported(right)
}

val left_val: Any = left match {
case c: ColumnarCast =>
if (c.child.dataType.isInstanceOf[DecimalType] &&
Expand Down Expand Up @@ -247,16 +270,20 @@ class ColumnarMultiply(left: Expression, right: Expression, original: Expression
}
}

/**
 * Codegen support derived from the (possibly cast-stripped) operands
 * `left_val` and `right_val`; the right side is only consulted when the left
 * side reports support.
 *
 * NOTE(review): this view also shows a `gName`-gated override of the same
 * method for this class; only one of the two can remain - confirm which.
 */
override def supportColumnarCodegen(args: java.lang.Object): Boolean = {
  val leftSupported =
    left_val.asInstanceOf[ColumnarExpression].supportColumnarCodegen(args)
  leftSupported &&
    right_val.asInstanceOf[ColumnarExpression].supportColumnarCodegen(args)
}
}

class ColumnarDivide(left: Expression, right: Expression,
original: Expression, resType: DecimalType = null)
extends Divide(left: Expression, right: Expression)
with ColumnarExpression
with Logging {
val gName = "divide"

/**
 * Whether this division may be compiled via columnar codegen. True only if
 * `gName` appears in `codegenFuncList` and both operands support codegen.
 */
override def supportColumnarCodegen(args: java.lang.Object): Boolean = {
  if (!codegenFuncList.contains(gName)) {
    false
  } else {
    val lhs = left.asInstanceOf[ColumnarExpression]
    lhs.supportColumnarCodegen(args) &&
      right.asInstanceOf[ColumnarExpression].supportColumnarCodegen(args)
  }
}

val left_val: Any = left match {
case c: ColumnarCast =>
Expand Down Expand Up @@ -349,6 +376,15 @@ class ColumnarBitwiseAnd(left: Expression, right: Expression, original: Expressi
extends BitwiseAnd(left: Expression, right: Expression)
with ColumnarExpression
with Logging {

val gName = "bitwise_and"

/**
 * Bitwise AND supports columnar codegen when `gName` is listed in
 * `codegenFuncList` and the left and right children both support it.
 */
override def supportColumnarCodegen(args: java.lang.Object): Boolean = {
  def operandSupported(operand: Expression): Boolean =
    operand.asInstanceOf[ColumnarExpression].supportColumnarCodegen(args)

  codegenFuncList.contains(gName) && operandSupported(left) && operandSupported(right)
}

override def doColumnarCodeGen(args: Object): (TreeNode, ArrowType) = {
val (left_node, left_type): (TreeNode, ArrowType) =
left.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args)
Expand All @@ -370,6 +406,15 @@ class ColumnarBitwiseOr(left: Expression, right: Expression, original: Expressio
extends BitwiseOr(left: Expression, right: Expression)
with ColumnarExpression
with Logging {

val gName = "bitwise_or"

/**
 * Bitwise OR supports columnar codegen when `gName` is listed in
 * `codegenFuncList` and each child (left, then right) supports it as well.
 */
override def supportColumnarCodegen(args: java.lang.Object): Boolean = {
  if (!codegenFuncList.contains(gName)) {
    false
  } else {
    val lhs = left.asInstanceOf[ColumnarExpression]
    lhs.supportColumnarCodegen(args) &&
      right.asInstanceOf[ColumnarExpression].supportColumnarCodegen(args)
  }
}

override def doColumnarCodeGen(args: Object): (TreeNode, ArrowType) = {
val (left_node, left_type): (TreeNode, ArrowType) =
left.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,14 @@ class ColumnarGetJsonObject(left: Expression, right: Expression, original: GetJs
class ColumnarStringInstr(left: Expression, right: Expression, original: StringInstr)
extends StringInstr(original.str, original.substr) with ColumnarExpression with Logging {

val gName = "locate"

/**
 * Columnar codegen is available for this string-instr node when `gName` is
 * found in `codegenFuncList` and both `left` and `right` report support.
 */
override def supportColumnarCodegen(args: java.lang.Object): Boolean = {
  def operandSupported(operand: Expression): Boolean =
    operand.asInstanceOf[ColumnarExpression].supportColumnarCodegen(args)

  codegenFuncList.contains(gName) && operandSupported(left) && operandSupported(right)
}

override def doColumnarCodeGen(args: Object): (TreeNode, ArrowType) = {
val (left_node, _): (TreeNode, ArrowType) =
left.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,14 @@ class ColumnarStartsWith(left: Expression, right: Expression, original: Expressi
extends StartsWith(left: Expression, right: Expression)
with ColumnarExpression
with Logging {
val gName = "starts_with"

/**
 * Reports codegen support for StartsWith: `gName` must be contained in
 * `codegenFuncList`, and both operands must support codegen themselves.
 */
override def supportColumnarCodegen(args: java.lang.Object): Boolean = {
  if (!codegenFuncList.contains(gName)) {
    false
  } else {
    val lhs = left.asInstanceOf[ColumnarExpression]
    lhs.supportColumnarCodegen(args) &&
      right.asInstanceOf[ColumnarExpression].supportColumnarCodegen(args)
  }
}

override def doColumnarCodeGen(args: java.lang.Object): (TreeNode, ArrowType) = {
val (left_node, left_type): (TreeNode, ArrowType) =
left.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args)
Expand All @@ -121,6 +129,14 @@ class ColumnarLike(left: Expression, right: Expression, original: Expression)
with ColumnarExpression
with Logging {

val gName = "like"

/**
 * Reports codegen support for Like: true only when `gName` is listed in
 * `codegenFuncList` and the left and right children both support codegen.
 */
override def supportColumnarCodegen(args: java.lang.Object): Boolean = {
  def operandSupported(operand: Expression): Boolean =
    operand.asInstanceOf[ColumnarExpression].supportColumnarCodegen(args)

  codegenFuncList.contains(gName) && operandSupported(left) && operandSupported(right)
}

buildCheck()

def buildCheck(): Unit = {
Expand Down Expand Up @@ -156,7 +172,14 @@ class ColumnarRLike(left: Expression, right: Expression, original: Expression)
extends RLike(left: Expression, right: Expression)
with ColumnarExpression
with Logging {
// Function name looked up in `codegenFuncList`. It was "equal", the same name
// ColumnarEqualTo uses, which gated regex matching on the equality function.
// NOTE(review): "rlike" is assumed to be the name used by doColumnarCodeGen -
// confirm. If the name is not registered, the check below returns false, which
// matches the behaviour before this method was gated on gName.
val gName = "rlike"

/**
 * Reports codegen support for RLike: `gName` must be listed in
 * `codegenFuncList` and both operands must support codegen themselves.
 */
override def supportColumnarCodegen(args: java.lang.Object): Boolean = {
  codegenFuncList.contains(gName) &&
    left.asInstanceOf[ColumnarExpression].supportColumnarCodegen(args) &&
    right.asInstanceOf[ColumnarExpression].supportColumnarCodegen(args)
}

buildCheck()

def buildCheck(): Unit = {
Expand All @@ -166,10 +189,6 @@ class ColumnarRLike(left: Expression, right: Expression, original: Expression)
}
}

/**
 * Unconditionally opts this expression out of columnar codegen.
 *
 * NOTE(review): this view also shows a `gName`-gated override of the same
 * method for this class; only one of the two can remain - confirm which.
 */
override def supportColumnarCodegen(args: java.lang.Object): Boolean = false

override def doColumnarCodeGen(args: java.lang.Object): (TreeNode, ArrowType) = {
val (left_node, left_type): (TreeNode, ArrowType) =
left.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args)
Expand Down Expand Up @@ -208,6 +227,14 @@ class ColumnarEqualTo(left: Expression, right: Expression, original: Expression)
extends EqualTo(left: Expression, right: Expression)
with ColumnarExpression
with Logging {
val gName = "equal"

/**
 * Reports codegen support for EqualTo: `gName` must be listed in
 * `codegenFuncList` and both sides of the comparison must support codegen.
 */
override def supportColumnarCodegen(args: java.lang.Object): Boolean = {
  if (!codegenFuncList.contains(gName)) {
    false
  } else {
    val lhs = left.asInstanceOf[ColumnarExpression]
    lhs.supportColumnarCodegen(args) &&
      right.asInstanceOf[ColumnarExpression].supportColumnarCodegen(args)
  }
}

override def doColumnarCodeGen(args: java.lang.Object): (TreeNode, ArrowType) = {
var (left_node, left_type): (TreeNode, ArrowType) =
left.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args)
Expand Down Expand Up @@ -307,8 +334,11 @@ class ColumnarLessThan(left: Expression, right: Expression, original: Expression
with ColumnarExpression
with Logging {

val gName = "less_than"

/**
 * Reports codegen support for LessThan: `gName` must be listed in
 * `codegenFuncList` and both operands must support codegen themselves.
 *
 * The former leading `true && left...` term was dropped: it was a no-op
 * constant plus a second, redundant evaluation of the left operand.
 */
override def supportColumnarCodegen(args: java.lang.Object): Boolean = {
  codegenFuncList.contains(gName) &&
    left.asInstanceOf[ColumnarExpression].supportColumnarCodegen(args) &&
    right.asInstanceOf[ColumnarExpression].supportColumnarCodegen(args)
}

Expand Down Expand Up @@ -354,6 +384,14 @@ class ColumnarLessThanOrEqual(left: Expression, right: Expression, original: Exp
extends LessThanOrEqual(left: Expression, right: Expression)
with ColumnarExpression
with Logging {
val gName = "less_than_or_equal_to"

/**
 * Reports codegen support for LessThanOrEqual: requires `gName` in
 * `codegenFuncList` and codegen support from the left and right children.
 */
override def supportColumnarCodegen(args: java.lang.Object): Boolean = {
  def operandSupported(operand: Expression): Boolean =
    operand.asInstanceOf[ColumnarExpression].supportColumnarCodegen(args)

  codegenFuncList.contains(gName) && operandSupported(left) && operandSupported(right)
}

override def doColumnarCodeGen(args: java.lang.Object): (TreeNode, ArrowType) = {
var (left_node, left_type): (TreeNode, ArrowType) =
left.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args)
Expand Down Expand Up @@ -398,6 +436,15 @@ class ColumnarGreaterThan(left: Expression, right: Expression, original: Express
extends GreaterThan(left: Expression, right: Expression)
with ColumnarExpression
with Logging {

val gName = "greater_than"

/**
 * Reports codegen support for GreaterThan: `gName` must be listed in
 * `codegenFuncList`, and both operands (left first) must support codegen.
 */
override def supportColumnarCodegen(args: java.lang.Object): Boolean = {
  if (!codegenFuncList.contains(gName)) {
    false
  } else {
    val lhs = left.asInstanceOf[ColumnarExpression]
    lhs.supportColumnarCodegen(args) &&
      right.asInstanceOf[ColumnarExpression].supportColumnarCodegen(args)
  }
}

override def doColumnarCodeGen(args: java.lang.Object): (TreeNode, ArrowType) = {
var (left_node, left_type): (TreeNode, ArrowType) =
left.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args)
Expand Down Expand Up @@ -442,6 +489,14 @@ class ColumnarGreaterThanOrEqual(left: Expression, right: Expression, original:
extends GreaterThanOrEqual(left: Expression, right: Expression)
with ColumnarExpression
with Logging {
val gName = "greater_than_or_equal_to"

/**
 * Reports codegen support for GreaterThanOrEqual: requires `gName` in
 * `codegenFuncList` and codegen support from both children.
 */
override def supportColumnarCodegen(args: java.lang.Object): Boolean = {
  def operandSupported(operand: Expression): Boolean =
    operand.asInstanceOf[ColumnarExpression].supportColumnarCodegen(args)

  codegenFuncList.contains(gName) && operandSupported(left) && operandSupported(right)
}

override def doColumnarCodeGen(args: java.lang.Object): (TreeNode, ArrowType) = {
var (left_node, left_type): (TreeNode, ArrowType) =
left.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args)
Expand Down Expand Up @@ -486,6 +541,15 @@ class ColumnarShiftLeft(left: Expression, right: Expression, original: Expressio
extends ShiftLeft(left: Expression, right: Expression)
with ColumnarExpression
with Logging {

val gName = "shift_left"

/**
 * Reports codegen support for ShiftLeft: `gName` must be listed in
 * `codegenFuncList` and both operands must support codegen themselves.
 */
override def supportColumnarCodegen(args: java.lang.Object): Boolean = {
  if (!codegenFuncList.contains(gName)) {
    false
  } else {
    val lhs = left.asInstanceOf[ColumnarExpression]
    lhs.supportColumnarCodegen(args) &&
      right.asInstanceOf[ColumnarExpression].supportColumnarCodegen(args)
  }
}

override def doColumnarCodeGen(args: Object): (TreeNode, ArrowType) = {
var (left_node, left_type): (TreeNode, ArrowType) =
left.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args)
Expand All @@ -508,6 +572,15 @@ class ColumnarShiftRight(left: Expression, right: Expression, original: Expressi
extends ShiftRight(left: Expression, right: Expression)
with ColumnarExpression
with Logging {

val gName = "shift_right"

/**
 * Reports codegen support for ShiftRight: requires `gName` in
 * `codegenFuncList` and codegen support from the left and right children.
 */
override def supportColumnarCodegen(args: java.lang.Object): Boolean = {
  def operandSupported(operand: Expression): Boolean =
    operand.asInstanceOf[ColumnarExpression].supportColumnarCodegen(args)

  codegenFuncList.contains(gName) && operandSupported(left) && operandSupported(right)
}

override def doColumnarCodeGen(args: Object): (TreeNode, ArrowType) = {
val (left_node, left_type): (TreeNode, ArrowType) =
left.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ import org.apache.spark.sql.types._
import scala.collection.mutable.ListBuffer

/**
* Columnar implementation of CaseWhen, based on Gandiva if/else.
*/
class ColumnarCaseWhen(
branches: Seq[(Expression, Expression)],
Expand Down Expand Up @@ -64,9 +64,6 @@ class ColumnarCaseWhen(
}

override def doColumnarCodeGen(args: java.lang.Object): (TreeNode, ArrowType) = {
logInfo(s"children: ${branches.flatMap(b => b._1 :: b._2 :: Nil) ++ elseValue}")
logInfo(s"branches: $branches")
logInfo(s"else: $elseValue")
val i = 0
val exprs = branches.flatMap(b => b._1 :: b._2 :: Nil) ++ elseValue
val exprList = { exprs.filter(expr => !expr.isInstanceOf[Literal]) }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,16 @@ class ColumnarCoalesce(exps: Seq[Expression], original: Expression)
with ColumnarExpression
with Logging {

/**
 * Coalesce supports columnar codegen only if every input expression, once
 * converted to its columnar counterpart, supports codegen. An empty input
 * list yields true.
 *
 * Uses `forall` rather than `return` inside a `for` body: the latter is a
 * non-local return, implemented by throwing a control exception.
 *
 * NOTE(review): `args` is not forwarded to the children; a fresh empty list
 * is passed instead, as in the original - confirm this is intentional.
 */
override def supportColumnarCodegen(args: java.lang.Object): Boolean =
  exps.forall { expr =>
    val colExpr = ColumnarExpressionConverter.replaceWithColumnarExpression(expr)
    colExpr.asInstanceOf[ColumnarExpression].supportColumnarCodegen(Lists.newArrayList())
  }

buildCheck()

def buildCheck(): Unit = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,12 @@ object ColumnarDateTimeExpressions {

class ColumnarYear(child: Expression) extends Year(child) with
ColumnarExpression {
val gName = "extractYear"

/**
 * Year extraction supports columnar codegen when `gName` is listed in
 * `codegenFuncList` and the child expression supports codegen.
 */
override def supportColumnarCodegen(args: java.lang.Object): Boolean = {
  if (codegenFuncList.contains(gName)) {
    child.asInstanceOf[ColumnarExpression].supportColumnarCodegen(args)
  } else {
    false
  }
}

buildCheck()

Expand Down Expand Up @@ -420,6 +426,12 @@ object ColumnarDateTimeExpressions {

class ColumnarMicrosToTimestamp(child: Expression) extends MicrosToTimestamp(child) with
ColumnarExpression {
val gName = "micros_to_timestamp"

/**
 * Micros-to-timestamp conversion supports columnar codegen when `gName` is
 * listed in `codegenFuncList` and the child expression supports codegen.
 */
override def supportColumnarCodegen(args: java.lang.Object): Boolean = {
  val functionRegistered = codegenFuncList.contains(gName)
  functionRegistered &&
    child.asInstanceOf[ColumnarExpression].supportColumnarCodegen(args)
}

buildCheck()

Expand Down Expand Up @@ -453,6 +465,8 @@ object ColumnarDateTimeExpressions {
extends UnixTimestamp(left, right, timeZoneId, failOnError) with
ColumnarExpression {

val gName = "unix_seconds"

val yearMonthDayFormat = "yyyy-MM-dd"
val yearMonthDayTimeFormat = "yyyy-MM-dd HH:mm:ss"
val yearMonthDayTimeNoSepFormat = "yyyyMMddHHmmss"
Expand Down Expand Up @@ -488,7 +502,8 @@ object ColumnarDateTimeExpressions {
}

/**
 * Reports codegen support for this unix-timestamp expression: `gName` must be
 * listed in `codegenFuncList` and both operands must support codegen.
 *
 * The former leading `false && ...` term was removed; it short-circuited the
 * whole expression, so the `gName` and operand checks could never run.
 */
override def supportColumnarCodegen(args: Object): Boolean = {
  codegenFuncList.contains(gName) &&
    left.asInstanceOf[ColumnarExpression].supportColumnarCodegen(args) &&
    right.asInstanceOf[ColumnarExpression].supportColumnarCodegen(args)
}

Expand Down
Loading