Skip to content

Commit

Permalink
Add tests for codegen fallback inside HOF
Browse files Browse the repository at this point in the history
  • Loading branch information
Kimahriman committed Nov 25, 2024
1 parent 8440d62 commit 5c82d9c
Showing 1 changed file with 22 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,11 @@
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.{SparkException, SparkFunSuite, SparkRuntimeException}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
import org.apache.spark.sql.catalyst.expressions.Cast._
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._

Expand Down Expand Up @@ -149,6 +151,7 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper

val plusOne: Expression => Expression = x => x + 1
val plusIndex: (Expression, Expression) => Expression = (x, i) => x + i
val plusOneFallback: Expression => Expression = x => CodegenFallbackExpr(x + 1)

checkEvaluation(transform(ai0, plusOne), Seq(2, 3, 4))
checkEvaluation(transform(ai0, plusIndex), Seq(1, 3, 5))
Expand All @@ -158,6 +161,8 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(transform(transform(ai1, plusIndex), plusOne), Seq(2, null, 6))
checkEvaluation(transform(ain, plusOne), null)

checkEvaluation(transform(ai0, plusOneFallback), Seq(2, 3, 4))

val as0 = Literal.create(Seq("a", "b", "c"), ArrayType(StringType, containsNull = false))
val as1 = Literal.create(Seq("a", null, "c"), ArrayType(StringType, containsNull = true))
val asn = Literal.create(null, ArrayType(StringType, containsNull = false))
Expand Down Expand Up @@ -277,6 +282,7 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
val isEven: Expression => Expression = x => x % 2 === 0
val isNullOrOdd: Expression => Expression = x => x.isNull || x % 2 === 1
val indexIsEven: (Expression, Expression) => Expression = { case (_, idx) => idx % 2 === 0 }
val isEvenFallback: Expression => Expression = x => CodegenFallbackExpr(x % 2 === 0)

checkEvaluation(filter(ai0, isEven), Seq(2))
checkEvaluation(filter(ai0, isNullOrOdd), Seq(1, 3))
Expand All @@ -286,6 +292,8 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(filter(ain, isEven), null)
checkEvaluation(filter(ain, isNullOrOdd), null)

checkEvaluation(filter(ai0, isEvenFallback), Seq(2))

val as0 =
Literal.create(Seq("a0", "b1", "a2", "c3"), ArrayType(StringType, containsNull = false))
val as1 = Literal.create(Seq("a", null, "c"), ArrayType(StringType, containsNull = true))
Expand Down Expand Up @@ -321,6 +329,7 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
val isNullOrOdd: Expression => Expression = x => x.isNull || x % 2 === 1
val alwaysFalse: Expression => Expression = _ => Literal.FalseLiteral
val alwaysNull: Expression => Expression = _ => Literal(null, BooleanType)
val isEvenFallback: Expression => Expression = x => CodegenFallbackExpr(x % 2 === 0)

for (followThreeValuedLogic <- Seq(false, true)) {
withSQLConf(SQLConf.LEGACY_ARRAY_EXISTS_FOLLOWS_THREE_VALUED_LOGIC.key
Expand All @@ -337,6 +346,7 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(exists(ain, isNullOrOdd), null)
checkEvaluation(exists(ain, alwaysFalse), null)
checkEvaluation(exists(ain, alwaysNull), null)
checkEvaluation(exists(ai0, isEvenFallback), true)
}
}

Expand Down Expand Up @@ -383,6 +393,7 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
val isNullOrOdd: Expression => Expression = x => x.isNull || x % 2 === 1
val alwaysFalse: Expression => Expression = _ => Literal.FalseLiteral
val alwaysNull: Expression => Expression = _ => Literal(null, BooleanType)
val isEvenFallback: Expression => Expression = x => CodegenFallbackExpr(x % 2 === 0)

checkEvaluation(forall(ai0, isEven), true)
checkEvaluation(forall(ai0, isNullOrOdd), false)
Expand All @@ -401,6 +412,8 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(forall(ain, alwaysFalse), null)
checkEvaluation(forall(ain, alwaysNull), null)

checkEvaluation(forall(ai0, isEvenFallback), true)

val as0 =
Literal.create(Seq("a0", "a1", "a2", "a3"), ArrayType(StringType, containsNull = false))
val as1 = Literal.create(Seq(null, "b", "c"), ArrayType(StringType, containsNull = true))
Expand Down Expand Up @@ -886,3 +899,12 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
)))
}
}

case class CodegenFallbackExpr(child: Expression) extends UnaryExpression with CodegenFallback {
override def nullable: Boolean = child.nullable
override def dataType: DataType = child.dataType
override lazy val resolved = child.resolved
override def eval(input: InternalRow): Any = child.eval(input)
override protected def withNewChildInternal(newChild: Expression): CodegenFallbackExpr =
copy(child = newChild)
}

0 comments on commit 5c82d9c

Please sign in to comment.