Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-48280][SQL] Improve collation testing surface area using expression walking #46801

Closed
wants to merge 39 commits into from
Closed
Changes from 12 commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
3be0b6a
Add test
mihailom-db May 15, 2024
d630120
Enable more functions
mihailom-db May 15, 2024
357334e
Improve test for expression walking
mihailom-db May 30, 2024
8a95bed
Merge remote-tracking branch 'upstream/master' into SPARK-48280
mihailom-db May 30, 2024
56532d4
Add more functions
mihailom-db May 30, 2024
af1268e
Fix null problem
mihailom-db May 31, 2024
716e778
Merge remote-tracking branch 'upstream/master' into SPARK-48280
mihailom-db Jun 3, 2024
f5012ec
Fix conflicts
mihailom-db Jun 3, 2024
73be32b
Remove unused inports
mihailom-db Jun 3, 2024
394f85e
Remove prints
mihailom-db Jun 4, 2024
698fbcf
Fix trailing comma error
mihailom-db Jun 4, 2024
2c47eaf
Add polishing
mihailom-db Jun 4, 2024
ba680db
Add new Suite
mihailom-db Jun 5, 2024
2f3fc4c
Revert changes in CollationSuite
mihailom-db Jun 5, 2024
e4ea17d
Refactor code
mihailom-db Jun 5, 2024
263c141
Add MapType support
mihailom-db Jun 6, 2024
29bb400
Add support for StructType
mihailom-db Jun 6, 2024
55f84da
Remove unnecessary prints
mihailom-db Jun 6, 2024
ba90ca5
Improve comment
mihailom-db Jun 6, 2024
7017d80
Improve comment
mihailom-db Jun 6, 2024
497baa5
Add example walker
mihailom-db Jun 10, 2024
4e7b611
Add new test
mihailom-db Jun 11, 2024
61bd63f
Merge remote-tracking branch 'upstream/master' into SPARK-48280
mihailom-db Jun 12, 2024
8cdb7ad
Add codeGen test
mihailom-db Jun 13, 2024
776dcba
Fix test errors
mihailom-db Jun 13, 2024
ced5500
Add new test
mihailom-db Jun 14, 2024
2c98c8d
Merge remote-tracking branch 'upstream/master' into SPARK-48280
mihailom-db Jun 17, 2024
9446fd1
Enable fixed expressions
mihailom-db Jun 18, 2024
8fa75ce
Merge remote-tracking branch 'upstream/master' into SPARK-48280
mihailom-db Jun 20, 2024
9902b05
Polish code and improve tests
mihailom-db Jun 20, 2024
a789400
Merge remote-tracking branch 'refs/remotes/upstream/master' into SPAR…
mihailom-db Jul 1, 2024
e75ff5c
Incorporate changes
mihailom-db Jul 1, 2024
001f91d
Improve code
mihailom-db Jul 2, 2024
9082b5e
remove printing
mihailom-db Jul 2, 2024
10a358b
Run CIs with printing
mihailom-db Jul 2, 2024
9da174d
Merge remote-tracking branch 'upstream/master' into SPARK-48280
mihailom-db Jul 2, 2024
6729bc6
Merge remote-tracking branch 'upstream/master' into SPARK-48280
mihailom-db Jul 3, 2024
57d11c8
Remove printing
mihailom-db Jul 3, 2024
bfbc9dc
Merge branch 'apache:master' into SPARK-48280
mihailom-db Jul 4, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
212 changes: 210 additions & 2 deletions sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,13 @@ import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec}
import org.apache.spark.sql.execution.joins._
import org.apache.spark.sql.internal.{SqlApiConf, SQLConf}
import org.apache.spark.sql.types.{MapType, StringType, StructField, StructType}
import org.apache.spark.sql.internal.types.{AbstractArrayType, StringTypeAnyCollation, StringTypeBinaryLcase}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.Utils

class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
class CollationSuite extends DatasourceV2SQLBase
with AdaptiveSparkPlanHelper with ExpressionEvalHelper {
protected val v2Source = classOf[FakeV2ProviderWithCustomSchema].getName

private val collationPreservingSources = Seq("parquet")
Expand Down Expand Up @@ -948,6 +952,210 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
}
}

test("SPARK-48280: Expression Walker for Testing") {
  // This test does the following:
  // 1) Take all expressions registered in the function registry.
  // 2) Keep only the ones that accept at least one argument of StringType.
  // 3) Use reflection to create an instance of the expression using the first
  //    constructor (test others as well).
  // 4) Check if the expression is of type ExpectsInputTypes (should make this a bit broader).
  // 5) Run eval against literals with strings under:
  //    a) UTF8_BINARY, "dummy string" as input.
  //    b) UTF8_BINARY_LCASE, "DuMmY sTrInG" as input.
  // 6) Check if both expressions throw an exception.
  // 7) If no exception, check if the result is the same.
  // 8) There is a list of allowed expressions that can differ (e.g. hex).

  // True if `inputType` accepts a (possibly nested) string argument.
  def hasStringType(inputType: AbstractDataType): Boolean = {
    inputType match {
      case _: StringType | StringTypeAnyCollation | StringTypeBinaryLcase | AnyDataType =>
        true
      case ArrayType => true
      case ArrayType(elementType, _) => hasStringType(elementType)
      case AbstractArrayType(elementType) => hasStringType(elementType)
      case TypeCollection(typeCollection) =>
        typeCollection.exists(hasStringType)
      case _ => false
    }
  }

  val funInfos = spark.sessionState.functionRegistry.listFunction().map { funcId =>
    spark.sessionState.catalog.lookupFunctionInfo(funcId)
  }.filter(funInfo => {
    // Make sure that there is a constructor.
    val cl = Utils.classForName(funInfo.getClassName)
    cl.getConstructors.nonEmpty
  }).filter(funInfo => {
    val cl = Utils.classForName(funInfo.getClassName)
    // Build a dummy instance using the first constructor so that the declared
    // inputTypes can be inspected.
    val headConstructor = cl.getConstructors.head

    val params = headConstructor.getParameters.map(p => p.getType)
    val allExpressions = params.forall(p => p.isAssignableFrom(classOf[Expression]) ||
      p.isAssignableFrom(classOf[Seq[Expression]]) ||
      p.isAssignableFrom(classOf[Option[Expression]]))

    if (!allExpressions) {
      false
    } else {
      val args = params.map {
        case e if e.isAssignableFrom(classOf[Expression]) => Literal.create("1")
        case se if se.isAssignableFrom(classOf[Seq[Expression]]) =>
          Seq(Literal.create("1"), Literal.create("2"))
        case oe if oe.isAssignableFrom(classOf[Option[Expression]]) => None
      }
      // Find all expressions that have string as input.
      try {
        headConstructor.newInstance(args: _*) match {
          case types: ExpectsInputTypes =>
            types.inputTypes.exists(hasStringType)
          // BUGFIX: a non-ExpectsInputTypes instance previously hit a
          // MatchError that was silently swallowed below; make it explicit.
          case _ => false
        }
      } catch {
        case _: Throwable => false
      }
    }
  }).toArray

  // Helper types for generating collation-specific data.
  sealed trait CollationType
  case object Utf8Binary extends CollationType
  case object Utf8BinaryLcase extends CollationType

  // Produces a literal of the requested `inputType`, using a collated string
  // literal wherever a string is expected.
  def generateSingleEntry(
      inputType: AbstractDataType,
      collationType: CollationType): Expression =
    inputType match {
      // Try to make this a bit more random.
      case AnyTimestampType => Literal("2009-07-30 12:58:59")
      case BinaryType => Literal(new Array[Byte](5))
      case BooleanType => Literal(true)
      case _: DatetimeType => Literal(1L)
      case _: DecimalType => Literal(new Decimal)
      case IntegerType | NumericType => Literal(1)
      case LongType => Literal(1L)
      case _: StringType | StringTypeAnyCollation | StringTypeBinaryLcase | AnyDataType =>
        collationType match {
          case Utf8Binary =>
            Literal.create("dummy string", StringType("UTF8_BINARY"))
          case Utf8BinaryLcase =>
            Literal.create("DuMmY sTrInG", StringType("UTF8_BINARY_LCASE"))
        }
      case TypeCollection(typeCollection) =>
        val strTypes = typeCollection.filter(hasStringType)
        if (strTypes.isEmpty) {
          // Take first type.
          generateSingleEntry(typeCollection.head, collationType)
        } else {
          // Take first string type.
          generateSingleEntry(strTypes.head, collationType)
        }
      case AbstractArrayType(elementType) =>
        generateSingleEntry(elementType, collationType).map(
          lit => Literal.create(Seq(lit.asInstanceOf[Literal].value), ArrayType(lit.dataType))
        ).head
      case ArrayType(elementType, _) =>
        generateSingleEntry(elementType, collationType).map(
          lit => Literal.create(Seq(lit.asInstanceOf[Literal].value), ArrayType(lit.dataType))
        ).head
      case ArrayType =>
        generateSingleEntry(StringTypeAnyCollation, collationType).map(
          lit => Literal.create(Seq(lit.asInstanceOf[Literal].value), ArrayType(lit.dataType))
        ).head
    }

  def generateData(
      inputTypes: Seq[AbstractDataType],
      collationType: CollationType): Seq[Expression] = {
    inputTypes.map(generateSingleEntry(_, collationType))
  }

  // Expressions whose results may legitimately differ across collations,
  // or which this walker cannot instantiate meaningfully.
  val toSkip = List(
    "get_json_object",
    "map_zip_with",
    "printf",
    "transform_keys",
    "concat_ws",
    "format_string",
    "session_window",
    "transform_values",
    "arrays_zip",
    "hex" // this is fine
  )

  for (f <- funInfos.filter(f => !toSkip.contains(f.getName))) {
    val cl = Utils.classForName(f.getClassName)
    val headConstructor = cl.getConstructors.head
    val params = headConstructor.getParameters.map(p => p.getType)
    val paramCount = params.length
    val args = params.map {
      case e if e.isAssignableFrom(classOf[Expression]) => Literal.create("1")
      case se if se.isAssignableFrom(classOf[Seq[Expression]]) =>
        Seq(Literal.create("1"))
      case oe if oe.isAssignableFrom(classOf[Option[Expression]]) => None
    }
    val expr = headConstructor.newInstance(args: _*).asInstanceOf[ExpectsInputTypes]
    val inputTypes = expr.inputTypes

    // Pad trailing Option[Expression] parameters with None when the
    // constructor takes more parameters than there are declared input types.
    // (Was a `var` with a conditional reassignment; a guarded `val` suffices.)
    val nones = Array.fill(math.max(0, paramCount - inputTypes.length))(None)

    val inputDataUtf8Binary = generateData(inputTypes.take(paramCount), Utf8Binary) ++ nones
    val instanceUtf8Binary =
      headConstructor.newInstance(inputDataUtf8Binary: _*).asInstanceOf[Expression]

    val inputDataLcase = generateData(inputTypes.take(paramCount), Utf8BinaryLcase) ++ nones
    val instanceLcase = headConstructor.newInstance(inputDataLcase: _*).asInstanceOf[Expression]

    val exceptionUtfBinary = {
      try {
        instanceUtf8Binary.eval(EmptyRow)
        None
      } catch {
        case e: Throwable => Some(e)
      }
    }

    val exceptionLcase = {
      try {
        instanceLcase.eval(EmptyRow)
        None
      } catch {
        case e: Throwable => Some(e)
      }
    }

    // Check that both cases either throw or pass.
    assert(exceptionUtfBinary.isDefined == exceptionLcase.isDefined)

    if (exceptionUtfBinary.isEmpty) {
      val resUtf8Binary = instanceUtf8Binary.eval(EmptyRow)
      val resUtf8Lcase = instanceLcase.eval(EmptyRow)

      instanceLcase.dataType match {
        // BUGFIX: the null guard previously tested `resUtf8Lcase != null`
        // twice and never checked `resUtf8Binary`.
        case _: StringType if resUtf8Binary != null && resUtf8Lcase != null =>
          assert(resUtf8Binary.isInstanceOf[UTF8String])
          assert(resUtf8Lcase.isInstanceOf[UTF8String])
          // scalastyle:off caselocale
          assert(resUtf8Binary.asInstanceOf[UTF8String].toLowerCase.binaryEquals(
            resUtf8Lcase.asInstanceOf[UTF8String].toLowerCase))
          // scalastyle:on caselocale
        // BUGFIX: the comparison result was previously discarded; wrap it in
        // an assertion so a mismatch actually fails the test.
        case _ => assert(resUtf8Lcase === resUtf8Binary)
      }
    } else {
      assert(exceptionUtfBinary.get.getClass == exceptionLcase.get.getClass)
    }
  }
}

test("Support operations on complex types containing collated strings") {
checkAnswer(sql("select reverse('abc' collate utf8_binary_lcase)"), Seq(Row("cba")))
checkAnswer(sql(
Expand Down