From 3be0b6a00f88fcad2a695cf494b3fae1fe50736f Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Wed, 15 May 2024 08:53:44 +0200 Subject: [PATCH 01/30] Add test --- .../org/apache/spark/sql/CollationSuite.scala | 198 +++++++++++++++++- 1 file changed, 195 insertions(+), 3 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala index b22a762a29547..cac312c2e7cba 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala @@ -21,7 +21,7 @@ import scala.jdk.CollectionConverters.MapHasAsJava import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.ExtendedAnalysisException -import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.catalyst.expressions.{EmptyRow, ExpectsInputTypes, Expression, ExpressionEvalHelper, Literal} import org.apache.spark.sql.catalyst.util.CollationFactory import org.apache.spark.sql.connector.{DatasourceV2SQLBase, FakeV2ProviderWithCustomSchema} import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryTable} @@ -32,9 +32,13 @@ import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec} import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec} import org.apache.spark.sql.internal.SqlApiConf -import org.apache.spark.sql.types.{MapType, StringType, StructField, StructType} +import org.apache.spark.sql.internal.types.{AbstractArrayType, StringTypeAnyCollation, StringTypeBinaryLcase} +import org.apache.spark.sql.types.{AbstractDataType, ArrayType, IntegerType, LongType, MapType, NumericType, StringType, StructField, StructType, TimestampType, TypeCollection} +import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.util.Utils -class CollationSuite 
extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { +class CollationSuite extends DatasourceV2SQLBase + with AdaptiveSparkPlanHelper with ExpressionEvalHelper { protected val v2Source = classOf[FakeV2ProviderWithCustomSchema].getName private val collationPreservingSources = Seq("parquet") @@ -979,6 +983,194 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { } } + test("SPARK-48280: Expression Walker for Testing") { + // This test does following: + // 1) Take all expressions + // 2) filter out ones that have at least one argument of string type + // 3) Use reflection to create an instance of the expression using first constructor + // (test other as well). + // 4) Check if the expression is of type ExpectsInputTypes (should make this a bit broader) + // 5) Run eval against literals with strings under: + // a) UTF8_BINARY, "dummy string" as input. + // b) UTF8_BINARY_LCASE, "DuMmY sTrInG" as input. + // 6) Check if both expressions throw an exception. + // 7) If no exception, check if the result is the same. + // 8) There is a list of allowed expressions that can differ (e.g. hex) + // + // We currently capture 75/449 expressions. We could do better. + val funInfos = spark.sessionState.functionRegistry.listFunction().map { funcId => + spark.sessionState.catalog.lookupFunctionInfo(funcId) + }.filter(funInfo => { + // make sure that there is a constructor. + val cl = Utils.classForName(funInfo.getClassName) + !cl.getConstructors.isEmpty + }).filter(funInfo => { + val className = funInfo.getClassName + // noinspection ScalaStyle + // println("checking - " + className) + val cl = Utils.classForName(funInfo.getClassName) + // dummy instance + // Take first constructor. 
+ val headConstructor = cl.getConstructors.head + + val paramCount = headConstructor.getParameterCount + val allExpressions = headConstructor.getParameters.map(p => p.getType) + .forall(p => p.isAssignableFrom(classOf[Expression])) + + if (!allExpressions) { + false + } else { + val args = Array.fill(paramCount)(Literal.create(1)) + // Find all expressions that have string as input + try { + val expr = headConstructor.newInstance(args: _*) + expr match { + case types: ExpectsInputTypes => + val inputTypes = types.inputTypes + // check if this is a collection... + inputTypes.exists { + case _: StringType | StringTypeAnyCollation | StringTypeBinaryLcase => true + case TypeCollection(typeCollection) => + typeCollection.exists { + case _: StringType | StringTypeAnyCollation | StringTypeBinaryLcase => true + case _ => false + } + case _ => false + } + case _ => + // Check other expressions here... + false + } + } catch { + // TODO: Try to get rid of this... + case _: Throwable => + false + } + } + }).toArray + + // noinspection ScalaStyle + println("Found total of " + funInfos.size + " functions") + // 75/449 + // We could capture more probably... + + // Helper methods for generating data. + sealed trait CollationType + case object Utf8Binary extends CollationType + case object Utf8BinaryLcase extends CollationType + + // TODO: There is probably some nicer way to do this... 
+ def generateData( + types: Seq[AbstractDataType], + collationType: CollationType): Seq[Expression] = { + types.map { + case TypeCollection(typeCollection) => + val strTypes = + typeCollection.filter(dt => dt.isInstanceOf[StringType] || + dt == StringTypeAnyCollation) + if (strTypes.isEmpty) { + // Take any + generateData(typeCollection, collationType).head + } else { + generateData(strTypes, collationType).head + } + case _: StringType | StringTypeAnyCollation => + collationType match { + case Utf8Binary => + Literal.create("dummy string", StringType("UTF8_BINARY")) + case Utf8BinaryLcase => + Literal.create("DuMmY sTrInG", StringType("UTF8_BINARY_LCASE")) + } + // Try to make this a bit more random. + case IntegerType | NumericType => Literal(1) + case TimestampType | LongType => Literal(1L) + case AbstractArrayType(elementType) => + (elementType, collationType) match { + case (StringTypeAnyCollation, Utf8Binary) => + Literal.create(Seq("dummy string"), ArrayType(StringType("UTF8_BINARY"))) + case (StringTypeAnyCollation, Utf8BinaryLcase) => + Literal.create(Seq("dUmmY sTriNg"), ArrayType(StringType("UTF8_BINARY_LCASE"))) + case (_, _) => fail("unsupported type") + } + } + } + + val toSkip = List( + "next_day", // TODO: Add support/debug these. + "regexp_replace", + "trunc", + "aes_encrypt", // this is probably fine? 
+ "convert_timezone", + "substring", // TODO: this is test issue + "aes_decrypt", + "str_to_map", + "get_json_object", + "make_timestamp", + "overlay", + "hex", // this is fine + ) + + for (f <- funInfos.filter(f => !toSkip.contains(f.getName))) { + // noinspection ScalaStyle + println(f.getName) + + val cl = Utils.classForName(f.getClassName) + val headConstructor = cl.getConstructors.head + val paramCount = headConstructor.getParameterCount + val args = Array.fill(paramCount)(Literal(1)) + val expr = headConstructor.newInstance(args: _*).asInstanceOf[ExpectsInputTypes] + val inputTypes = expr.inputTypes + + val inputDataUtf8Binary = generateData(inputTypes, Utf8Binary) + val instanceUtf8Binary = + headConstructor.newInstance(inputDataUtf8Binary: _*).asInstanceOf[Expression] + + val inputDataLcase = generateData(inputTypes, Utf8BinaryLcase) + val instanceLcase = headConstructor.newInstance(inputDataLcase: _*).asInstanceOf[Expression] + + val exceptionUtfBinary = { + try { + instanceUtf8Binary.eval(EmptyRow) + None + } catch { + case e: Throwable => Some(e) + } + } + + val exceptionLcase = { + try { + instanceLcase.eval(EmptyRow) + None + } catch { + case e: Throwable => Some(e) + } + } + + // if exception, assert that both cases have exception. + // TODO: check if exception is the same. + assert(exceptionUtfBinary.isDefined == exceptionLcase.isDefined) + + // no exception - check result. 
+ if (exceptionUtfBinary.isEmpty) { + val resUtf8Binary = instanceUtf8Binary.eval(EmptyRow) + val resUtf8Lcase = instanceLcase.eval(EmptyRow) + + val dt = instanceLcase.dataType + + dt match { + case st: StringType => + assert(resUtf8Binary.isInstanceOf[UTF8String]) + assert(resUtf8Lcase.isInstanceOf[UTF8String]) + // scalastyle:off caselocale + assert(resUtf8Binary.asInstanceOf[UTF8String].toLowerCase.binaryEquals( + resUtf8Lcase.asInstanceOf[UTF8String].toLowerCase)) + // scalastyle:on caselocale + case _ => resUtf8Lcase === resUtf8Binary + } + } + } + } + test("Support operations on complex types containing collated strings") { checkAnswer(sql("select reverse('abc' collate utf8_binary_lcase)"), Seq(Row("cba"))) checkAnswer(sql( From d63012057e39c4337ca0cbfeeb6952ddc0735205 Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Wed, 15 May 2024 12:48:59 +0200 Subject: [PATCH 02/30] Enable more functions --- .../org/apache/spark/sql/CollationSuite.scala | 32 +++++++++---------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala index cac312c2e7cba..493f8d3c406d8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAg import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec} import org.apache.spark.sql.internal.SqlApiConf import org.apache.spark.sql.internal.types.{AbstractArrayType, StringTypeAnyCollation, StringTypeBinaryLcase} -import org.apache.spark.sql.types.{AbstractDataType, ArrayType, IntegerType, LongType, MapType, NumericType, StringType, StructField, StructType, TimestampType, TypeCollection} +import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import 
org.apache.spark.util.Utils @@ -1007,7 +1007,7 @@ class CollationSuite extends DatasourceV2SQLBase }).filter(funInfo => { val className = funInfo.getClassName // noinspection ScalaStyle - // println("checking - " + className) + println("checking - " + className) val cl = Utils.classForName(funInfo.getClassName) // dummy instance // Take first constructor. @@ -1018,6 +1018,8 @@ class CollationSuite extends DatasourceV2SQLBase .forall(p => p.isAssignableFrom(classOf[Expression])) if (!allExpressions) { + // noinspection ScalaStyle + println("NotAll") false } else { val args = Array.fill(paramCount)(Literal.create(1)) @@ -1039,11 +1041,15 @@ class CollationSuite extends DatasourceV2SQLBase } case _ => // Check other expressions here... + // noinspection ScalaStyle + println("NotExpects") false } } catch { // TODO: Try to get rid of this... case _: Throwable => + // noinspection ScalaStyle + println("ErrorsOut") false } } @@ -1067,14 +1073,14 @@ class CollationSuite extends DatasourceV2SQLBase case TypeCollection(typeCollection) => val strTypes = typeCollection.filter(dt => dt.isInstanceOf[StringType] || - dt == StringTypeAnyCollation) + dt == StringTypeAnyCollation || dt == StringTypeBinaryLcase) if (strTypes.isEmpty) { // Take any generateData(typeCollection, collationType).head } else { generateData(strTypes, collationType).head } - case _: StringType | StringTypeAnyCollation => + case _: StringType | StringTypeAnyCollation | StringTypeBinaryLcase => collationType match { case Utf8Binary => Literal.create("dummy string", StringType("UTF8_BINARY")) @@ -1083,7 +1089,10 @@ class CollationSuite extends DatasourceV2SQLBase } // Try to make this a bit more random. 
case IntegerType | NumericType => Literal(1) - case TimestampType | LongType => Literal(1L) + case LongType => Literal(1L) + case _: DecimalType => Literal(new Decimal) + case BinaryType => Literal(new Array[Byte](5)) + case dt if dt.isInstanceOf[DatetimeType] => Literal(1L) case AbstractArrayType(elementType) => (elementType, collationType) match { case (StringTypeAnyCollation, Utf8Binary) => @@ -1097,16 +1106,7 @@ class CollationSuite extends DatasourceV2SQLBase val toSkip = List( "next_day", // TODO: Add support/debug these. - "regexp_replace", - "trunc", - "aes_encrypt", // this is probably fine? - "convert_timezone", - "substring", // TODO: this is test issue - "aes_decrypt", - "str_to_map", "get_json_object", - "make_timestamp", - "overlay", "hex", // this is fine ) @@ -1121,11 +1121,11 @@ class CollationSuite extends DatasourceV2SQLBase val expr = headConstructor.newInstance(args: _*).asInstanceOf[ExpectsInputTypes] val inputTypes = expr.inputTypes - val inputDataUtf8Binary = generateData(inputTypes, Utf8Binary) + val inputDataUtf8Binary = generateData(inputTypes.take(paramCount), Utf8Binary) val instanceUtf8Binary = headConstructor.newInstance(inputDataUtf8Binary: _*).asInstanceOf[Expression] - val inputDataLcase = generateData(inputTypes, Utf8BinaryLcase) + val inputDataLcase = generateData(inputTypes.take(paramCount), Utf8BinaryLcase) val instanceLcase = headConstructor.newInstance(inputDataLcase: _*).asInstanceOf[Expression] val exceptionUtfBinary = { From 357334eaf3e13d3b69d369ef9f9f54b60862299a Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Thu, 30 May 2024 08:08:27 +0200 Subject: [PATCH 03/30] Improve test for expression walking --- .../org/apache/spark/sql/CollationSuite.scala | 153 ++++++++++++------ 1 file changed, 102 insertions(+), 51 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala index 493f8d3c406d8..178234c957e55 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala @@ -21,7 +21,7 @@ import scala.jdk.CollectionConverters.MapHasAsJava import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.ExtendedAnalysisException -import org.apache.spark.sql.catalyst.expressions.{EmptyRow, ExpectsInputTypes, Expression, ExpressionEvalHelper, Literal} +import org.apache.spark.sql.catalyst.expressions.{ComplexTypeMergingExpression, EmptyRow, ExpectsInputTypes, Expression, ExpressionEvalHelper, InheritAnalysisRules, Literal} import org.apache.spark.sql.catalyst.util.CollationFactory import org.apache.spark.sql.connector.{DatasourceV2SQLBase, FakeV2ProviderWithCustomSchema} import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryTable} @@ -986,7 +986,7 @@ class CollationSuite extends DatasourceV2SQLBase test("SPARK-48280: Expression Walker for Testing") { // This test does following: // 1) Take all expressions - // 2) filter out ones that have at least one argument of string type + // 2) Filter out ones that have at least one argument of StringType // 3) Use reflection to create an instance of the expression using first constructor // (test other as well). // 4) Check if the expression is of type ExpectsInputTypes (should make this a bit broader) @@ -998,6 +998,19 @@ class CollationSuite extends DatasourceV2SQLBase // 8) There is a list of allowed expressions that can differ (e.g. hex) // // We currently capture 75/449 expressions. We could do better. 
+ def hasStringType(inputType: AbstractDataType): Boolean = { + inputType match { + case _: StringType | StringTypeAnyCollation | StringTypeBinaryLcase | AnyDataType => + true + case ArrayType => true + case ArrayType(elementType, _) => hasStringType(elementType) + case AbstractArrayType(elementType) => hasStringType(elementType) + case TypeCollection(typeCollection) => + typeCollection.exists(hasStringType) + case _ => false + } + } + val funInfos = spark.sessionState.functionRegistry.listFunction().map { funcId => spark.sessionState.catalog.lookupFunctionInfo(funcId) }.filter(funInfo => { @@ -1013,32 +1026,42 @@ class CollationSuite extends DatasourceV2SQLBase // Take first constructor. val headConstructor = cl.getConstructors.head - val paramCount = headConstructor.getParameterCount - val allExpressions = headConstructor.getParameters.map(p => p.getType) - .forall(p => p.isAssignableFrom(classOf[Expression])) + val params = headConstructor.getParameters.map(p => p.getType) + val allExpressions = params.forall(p => p.isAssignableFrom(classOf[Expression]) || + p.isAssignableFrom(classOf[Seq[Expression]]) || + p.isAssignableFrom(classOf[Option[Expression]])) if (!allExpressions) { // noinspection ScalaStyle println("NotAll") false } else { - val args = Array.fill(paramCount)(Literal.create(1)) + val args = params.map { + case e if e.isAssignableFrom(classOf[Expression]) => Literal.create("1") + case se if se.isAssignableFrom(classOf[Seq[Expression]]) => + Seq(Literal.create("1"), Literal.create("2")) + case oe if oe.isAssignableFrom(classOf[Option[Expression]]) => None + } // Find all expressions that have string as input try { val expr = headConstructor.newInstance(args: _*) expr match { case types: ExpectsInputTypes => + // noinspection ScalaStyle + println("AllExpects") val inputTypes = types.inputTypes // check if this is a collection... 
- inputTypes.exists { - case _: StringType | StringTypeAnyCollation | StringTypeBinaryLcase => true - case TypeCollection(typeCollection) => - typeCollection.exists { - case _: StringType | StringTypeAnyCollation | StringTypeBinaryLcase => true - case _ => false - } - case _ => false - } + inputTypes.exists(hasStringType) + case _: ComplexTypeMergingExpression => + // Check other expressions here... + // noinspection ScalaStyle + println("TypeForMerging") + false + case _: InheritAnalysisRules => + // Check other expressions here... + // noinspection ScalaStyle + println("Inherit") + false case _ => // Check other expressions here... // noinspection ScalaStyle @@ -1057,56 +1080,71 @@ class CollationSuite extends DatasourceV2SQLBase // noinspection ScalaStyle println("Found total of " + funInfos.size + " functions") - // 75/449 - // We could capture more probably... // Helper methods for generating data. sealed trait CollationType case object Utf8Binary extends CollationType case object Utf8BinaryLcase extends CollationType - // TODO: There is probably some nicer way to do this... - def generateData( - types: Seq[AbstractDataType], - collationType: CollationType): Seq[Expression] = { - types.map { - case TypeCollection(typeCollection) => - val strTypes = - typeCollection.filter(dt => dt.isInstanceOf[StringType] || - dt == StringTypeAnyCollation || dt == StringTypeBinaryLcase) - if (strTypes.isEmpty) { - // Take any - generateData(typeCollection, collationType).head - } else { - generateData(strTypes, collationType).head - } - case _: StringType | StringTypeAnyCollation | StringTypeBinaryLcase => + def generateSingleEntry( + inputType: AbstractDataType, + collationType: CollationType): Expression = + inputType match { + // Try to make this a bit more random. 
+ case AnyTimestampType => Literal("2009-07-30 12:58:59") + case BinaryType => Literal(new Array[Byte](5)) + case BooleanType => Literal(true) + case _: DatetimeType => Literal(1L) + case _: DecimalType => Literal(new Decimal) + case IntegerType | NumericType => Literal(1) + case LongType => Literal(1L) + case _: StringType | StringTypeAnyCollation | StringTypeBinaryLcase | AnyDataType => collationType match { case Utf8Binary => Literal.create("dummy string", StringType("UTF8_BINARY")) case Utf8BinaryLcase => Literal.create("DuMmY sTrInG", StringType("UTF8_BINARY_LCASE")) } - // Try to make this a bit more random. - case IntegerType | NumericType => Literal(1) - case LongType => Literal(1L) - case _: DecimalType => Literal(new Decimal) - case BinaryType => Literal(new Array[Byte](5)) - case dt if dt.isInstanceOf[DatetimeType] => Literal(1L) - case AbstractArrayType(elementType) => - (elementType, collationType) match { - case (StringTypeAnyCollation, Utf8Binary) => - Literal.create(Seq("dummy string"), ArrayType(StringType("UTF8_BINARY"))) - case (StringTypeAnyCollation, Utf8BinaryLcase) => - Literal.create(Seq("dUmmY sTriNg"), ArrayType(StringType("UTF8_BINARY_LCASE"))) - case (_, _) => fail("unsupported type") + case TypeCollection(typeCollection) => + val strTypes = typeCollection.filter(hasStringType) + if (strTypes.isEmpty) { + // Take first type + generateSingleEntry(typeCollection.head, collationType) + } else { + // Take first string type + generateSingleEntry(strTypes.head, collationType) } - } + case AbstractArrayType(elementType) => + generateSingleEntry(elementType, collationType).map( + lit => Literal.create(Seq(lit.asInstanceOf[Literal].value), ArrayType(lit.dataType)) + ).head + case ArrayType(elementType, _) => + generateSingleEntry(elementType, collationType).map( + lit => Literal.create(Seq(lit.asInstanceOf[Literal].value), ArrayType(lit.dataType)) + ).head + case ArrayType => + generateSingleEntry(StringTypeAnyCollation, collationType).map( + lit 
=> Literal.create(Seq(lit.asInstanceOf[Literal].value), ArrayType(lit.dataType)) + ).head } + def generateData( + inputTypes: Seq[AbstractDataType], + collationType: CollationType): Seq[Expression] = { + inputTypes.map(generateSingleEntry(_, collationType)) + } + val toSkip = List( "next_day", // TODO: Add support/debug these. "get_json_object", + "map_zip_with", + "printf", + "transform_keys", + "concat_ws", + "format_string", + "session_window", + "transform_values", + "arrays_zip", "hex", // this is fine ) @@ -1116,16 +1154,27 @@ class CollationSuite extends DatasourceV2SQLBase val cl = Utils.classForName(f.getClassName) val headConstructor = cl.getConstructors.head - val paramCount = headConstructor.getParameterCount - val args = Array.fill(paramCount)(Literal(1)) + val params = headConstructor.getParameters.map(p => p.getType) + val paramCount = params.length + val args = params.map { + case e if e.isAssignableFrom(classOf[Expression]) => Literal.create("1") + case se if se.isAssignableFrom(classOf[Seq[Expression]]) => + Seq(Literal.create("1")) + case oe if oe.isAssignableFrom(classOf[Option[Expression]]) => None + } val expr = headConstructor.newInstance(args: _*).asInstanceOf[ExpectsInputTypes] val inputTypes = expr.inputTypes - val inputDataUtf8Binary = generateData(inputTypes.take(paramCount), Utf8Binary) + var nones = Array.fill(0)(None) + if (paramCount > inputTypes.length) { + nones = Array.fill(paramCount - inputTypes.length)(None) + } + + val inputDataUtf8Binary = generateData(inputTypes.take(paramCount), Utf8Binary) ++ nones val instanceUtf8Binary = headConstructor.newInstance(inputDataUtf8Binary: _*).asInstanceOf[Expression] - val inputDataLcase = generateData(inputTypes.take(paramCount), Utf8BinaryLcase) + val inputDataLcase = generateData(inputTypes.take(paramCount), Utf8BinaryLcase) ++ nones val instanceLcase = headConstructor.newInstance(inputDataLcase: _*).asInstanceOf[Expression] val exceptionUtfBinary = { @@ -1152,6 +1201,8 @@ class 
CollationSuite extends DatasourceV2SQLBase // no exception - check result. if (exceptionUtfBinary.isEmpty) { + // scalastyle:off println + println("GOODPASS") val resUtf8Binary = instanceUtf8Binary.eval(EmptyRow) val resUtf8Lcase = instanceLcase.eval(EmptyRow) From 56532d45857cef40b17a71f57dda38cc1c8061d6 Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Thu, 30 May 2024 13:51:18 +0200 Subject: [PATCH 04/30] Add more functions --- .../src/test/scala/org/apache/spark/sql/CollationSuite.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala index e58eead9bceba..bb4ab46de902e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala @@ -1101,7 +1101,6 @@ class CollationSuite extends DatasourceV2SQLBase } val toSkip = List( - "next_day", // TODO: Add support/debug these. 
"get_json_object", "map_zip_with", "printf", From af1268e55afb78910157582ff174c783ee8dcafc Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Fri, 31 May 2024 10:17:04 +0200 Subject: [PATCH 05/30] Fix null problem --- .../src/test/scala/org/apache/spark/sql/CollationSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala index bb4ab46de902e..4e98b1b575513 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala @@ -1174,7 +1174,7 @@ class CollationSuite extends DatasourceV2SQLBase val dt = instanceLcase.dataType dt match { - case st: StringType => + case st: StringType if st != null => assert(resUtf8Binary.isInstanceOf[UTF8String]) assert(resUtf8Lcase.isInstanceOf[UTF8String]) // scalastyle:off caselocale From f5012ec01dc641f89873be969bcba81e0aac2199 Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Mon, 3 Jun 2024 19:18:55 +0200 Subject: [PATCH 06/30] Fix conflicts --- .../src/test/scala/org/apache/spark/sql/CollationSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala index 8f69e55ec9bd7..f8aae34d093bb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala @@ -1177,7 +1177,7 @@ class CollationSuite extends DatasourceV2SQLBase val dt = instanceLcase.dataType dt match { - case st: StringType if st != null => + case _: StringType if resUtf8Lcase != null && resUtf8Lcase != null => assert(resUtf8Binary.isInstanceOf[UTF8String]) assert(resUtf8Lcase.isInstanceOf[UTF8String]) // scalastyle:off caselocale From 73be32b40adf6dd5a1e76d40e9fadef576e72be2 Mon Sep 17 
00:00:00 2001 From: Mihailo Milosevic Date: Mon, 3 Jun 2024 19:56:06 +0200 Subject: [PATCH 07/30] Remove unused imports --- .../src/test/scala/org/apache/spark/sql/CollationSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala index f8aae34d093bb..0ec4650a0d2d7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec} import org.apache.spark.sql.execution.joins._ import org.apache.spark.sql.internal.{SqlApiConf, SQLConf} -import org.apache.spark.sql.internal.types.{AbstractArrayType, AbstractMapType, StringTypeAnyCollation, StringTypeBinaryLcase} +import org.apache.spark.sql.internal.types.{AbstractArrayType, StringTypeAnyCollation, StringTypeBinaryLcase} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils From 394f85e16181b944c5417ce4761902f969178f88 Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Tue, 4 Jun 2024 07:59:31 +0200 Subject: [PATCH 08/30] Remove prints --- .../org/apache/spark/sql/CollationSuite.scala | 23 ------------------- 1 file changed, 23 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala index 0ec4650a0d2d7..21ad47f23da00 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala @@ -988,8 +988,6 @@ class CollationSuite extends DatasourceV2SQLBase !cl.getConstructors.isEmpty }).filter(funInfo => { val className = funInfo.getClassName - // noinspection 
ScalaStyle - println("checking - " + className) val cl = Utils.classForName(funInfo.getClassName) // dummy instance // Take first constructor. @@ -1001,8 +999,6 @@ class CollationSuite extends DatasourceV2SQLBase p.isAssignableFrom(classOf[Option[Expression]])) if (!allExpressions) { - // noinspection ScalaStyle - println("NotAll") false } else { val args = params.map { @@ -1016,40 +1012,25 @@ class CollationSuite extends DatasourceV2SQLBase val expr = headConstructor.newInstance(args: _*) expr match { case types: ExpectsInputTypes => - // noinspection ScalaStyle - println("AllExpects") val inputTypes = types.inputTypes // check if this is a collection... inputTypes.exists(hasStringType) case _: ComplexTypeMergingExpression => - // Check other expressions here... - // noinspection ScalaStyle - println("TypeForMerging") false case _: InheritAnalysisRules => - // Check other expressions here... - // noinspection ScalaStyle - println("Inherit") false case _ => // Check other expressions here... - // noinspection ScalaStyle - println("NotExpects") false } } catch { // TODO: Try to get rid of this... case _: Throwable => - // noinspection ScalaStyle - println("ErrorsOut") false } } }).toArray - // noinspection ScalaStyle - println("Found total of " + funInfos.size + " functions") - // Helper methods for generating data. sealed trait CollationType case object Utf8Binary extends CollationType @@ -1117,8 +1098,6 @@ class CollationSuite extends DatasourceV2SQLBase ) for (f <- funInfos.filter(f => !toSkip.contains(f.getName))) { - // noinspection ScalaStyle - println(f.getName) val cl = Utils.classForName(f.getClassName) val headConstructor = cl.getConstructors.head @@ -1169,8 +1148,6 @@ class CollationSuite extends DatasourceV2SQLBase // no exception - check result. 
if (exceptionUtfBinary.isEmpty) { - // scalastyle:off println - println("GOODPASS") val resUtf8Binary = instanceUtf8Binary.eval(EmptyRow) val resUtf8Lcase = instanceLcase.eval(EmptyRow) From 698fbcfd51088c917d630e42d8294f32a4b1587f Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Tue, 4 Jun 2024 08:22:58 +0200 Subject: [PATCH 09/30] Fix trailing comma error --- .../org/apache/spark/sql/CollationSuite.scala | 25 ++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala index 21ad47f23da00..1fe57efb33ef3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala @@ -988,6 +988,8 @@ class CollationSuite extends DatasourceV2SQLBase !cl.getConstructors.isEmpty }).filter(funInfo => { val className = funInfo.getClassName + // noinspection ScalaStyle + println("checking - " + className) val cl = Utils.classForName(funInfo.getClassName) // dummy instance // Take first constructor. @@ -999,6 +1001,8 @@ class CollationSuite extends DatasourceV2SQLBase p.isAssignableFrom(classOf[Option[Expression]])) if (!allExpressions) { + // noinspection ScalaStyle + println("NotAll") false } else { val args = params.map { @@ -1012,25 +1016,40 @@ class CollationSuite extends DatasourceV2SQLBase val expr = headConstructor.newInstance(args: _*) expr match { case types: ExpectsInputTypes => + // noinspection ScalaStyle + println("AllExpects") val inputTypes = types.inputTypes // check if this is a collection... inputTypes.exists(hasStringType) case _: ComplexTypeMergingExpression => + // Check other expressions here... + // noinspection ScalaStyle + println("TypeForMerging") false case _: InheritAnalysisRules => + // Check other expressions here... 
+ // noinspection ScalaStyle + println("Inherit") false case _ => // Check other expressions here... + // noinspection ScalaStyle + println("NotExpects") false } } catch { // TODO: Try to get rid of this... case _: Throwable => + // noinspection ScalaStyle + println("ErrorsOut") false } } }).toArray + // noinspection ScalaStyle + println("Found total of " + funInfos.size + " functions") + // Helper methods for generating data. sealed trait CollationType case object Utf8Binary extends CollationType @@ -1094,10 +1113,12 @@ class CollationSuite extends DatasourceV2SQLBase "session_window", "transform_values", "arrays_zip", - "hex", // this is fine + "hex" // this is fine ) for (f <- funInfos.filter(f => !toSkip.contains(f.getName))) { + // noinspection ScalaStyle + println(f.getName) val cl = Utils.classForName(f.getClassName) val headConstructor = cl.getConstructors.head @@ -1148,6 +1169,8 @@ class CollationSuite extends DatasourceV2SQLBase // no exception - check result. if (exceptionUtfBinary.isEmpty) { + // scalastyle:off println + println("GOODPASS") val resUtf8Binary = instanceUtf8Binary.eval(EmptyRow) val resUtf8Lcase = instanceLcase.eval(EmptyRow) From 2c47eaf83dd08ca923abfa24dcd313b96736ffe8 Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Tue, 4 Jun 2024 13:59:29 +0200 Subject: [PATCH 10/30] Add polishing --- .../org/apache/spark/sql/CollationSuite.scala | 44 +++---------------- 1 file changed, 5 insertions(+), 39 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala index 1fe57efb33ef3..7e7e5ca62b1f8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala @@ -965,8 +965,6 @@ class CollationSuite extends DatasourceV2SQLBase // 6) Check if both expressions throw an exception. // 7) If no exception, check if the result is the same. 
// 8) There is a list of allowed expressions that can differ (e.g. hex) - // - // We currently capture 75/449 expressions. We could do better. def hasStringType(inputType: AbstractDataType): Boolean = { inputType match { case _: StringType | StringTypeAnyCollation | StringTypeBinaryLcase | AnyDataType => @@ -988,8 +986,6 @@ class CollationSuite extends DatasourceV2SQLBase !cl.getConstructors.isEmpty }).filter(funInfo => { val className = funInfo.getClassName - // noinspection ScalaStyle - println("checking - " + className) val cl = Utils.classForName(funInfo.getClassName) // dummy instance // Take first constructor. @@ -1001,8 +997,6 @@ class CollationSuite extends DatasourceV2SQLBase p.isAssignableFrom(classOf[Option[Expression]])) if (!allExpressions) { - // noinspection ScalaStyle - println("NotAll") false } else { val args = params.map { @@ -1016,40 +1010,16 @@ class CollationSuite extends DatasourceV2SQLBase val expr = headConstructor.newInstance(args: _*) expr match { case types: ExpectsInputTypes => - // noinspection ScalaStyle - println("AllExpects") val inputTypes = types.inputTypes // check if this is a collection... inputTypes.exists(hasStringType) - case _: ComplexTypeMergingExpression => - // Check other expressions here... - // noinspection ScalaStyle - println("TypeForMerging") - false - case _: InheritAnalysisRules => - // Check other expressions here... - // noinspection ScalaStyle - println("Inherit") - false - case _ => - // Check other expressions here... - // noinspection ScalaStyle - println("NotExpects") - false } } catch { - // TODO: Try to get rid of this... - case _: Throwable => - // noinspection ScalaStyle - println("ErrorsOut") - false + case _: Throwable => false } } }).toArray - // noinspection ScalaStyle - println("Found total of " + funInfos.size + " functions") - // Helper methods for generating data. 
sealed trait CollationType case object Utf8Binary extends CollationType @@ -1117,9 +1087,6 @@ class CollationSuite extends DatasourceV2SQLBase ) for (f <- funInfos.filter(f => !toSkip.contains(f.getName))) { - // noinspection ScalaStyle - println(f.getName) - val cl = Utils.classForName(f.getClassName) val headConstructor = cl.getConstructors.head val params = headConstructor.getParameters.map(p => p.getType) @@ -1163,14 +1130,10 @@ class CollationSuite extends DatasourceV2SQLBase } } - // if exception, assert that both cases have exception. - // TODO: check if exception is the same. + // Check that both cases either throw or pass assert(exceptionUtfBinary.isDefined == exceptionLcase.isDefined) - // no exception - check result. if (exceptionUtfBinary.isEmpty) { - // scalastyle:off println - println("GOODPASS") val resUtf8Binary = instanceUtf8Binary.eval(EmptyRow) val resUtf8Lcase = instanceLcase.eval(EmptyRow) @@ -1187,6 +1150,9 @@ class CollationSuite extends DatasourceV2SQLBase case _ => resUtf8Lcase === resUtf8Binary } } + else { + assert(exceptionUtfBinary.get.getClass == exceptionLcase.get.getClass) + } } } From ba680dbace00b407833b8039def1c4454c3d0d39 Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Wed, 5 Jun 2024 09:29:16 +0200 Subject: [PATCH 11/30] Add new Suite --- .../sql/CollationExpressionWalkerSuite.scala | 232 ++++++++++++++++++ .../org/apache/spark/sql/CollationSuite.scala | 207 ---------------- 2 files changed, 232 insertions(+), 207 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala new file mode 100644 index 0000000000000..8c93683ead1c2 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala @@ -0,0 +1,232 @@ +/* + * Licensed to the Apache Software 
Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.expressions.{EmptyRow, ExpectsInputTypes, Expression, Literal} +import org.apache.spark.sql.internal.types.{AbstractArrayType, StringTypeAnyCollation, StringTypeBinaryLcase} +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types.{AbstractDataType, AnyDataType, AnyTimestampType, ArrayType, BinaryType, BooleanType, DatetimeType, Decimal, DecimalType, IntegerType, LongType, NumericType, StringType, TypeCollection} +import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.util.Utils + +class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSession { + test("SPARK-48280: Expression Walker for Testing") { + // This test does following: + // 1) Take all expressions + // 2) Filter out ones that have at least one argument of StringType + // 3) Use reflection to create an instance of the expression using first constructor + // (test other as well). + // 4) Check if the expression is of type ExpectsInputTypes (should make this a bit broader) + // 5) Run eval against literals with strings under: + // a) UTF8_BINARY, "dummy string" as input. 
+ // b) UTF8_BINARY_LCASE, "DuMmY sTrInG" as input. + // 6) Check if both expressions throw an exception. + // 7) If no exception, check if the result is the same. + // 8) There is a list of allowed expressions that can differ (e.g. hex) + def hasStringType(inputType: AbstractDataType): Boolean = { + inputType match { + case _: StringType | StringTypeAnyCollation | StringTypeBinaryLcase | AnyDataType => + true + case ArrayType => true + case ArrayType(elementType, _) => hasStringType(elementType) + case AbstractArrayType(elementType) => hasStringType(elementType) + case TypeCollection(typeCollection) => + typeCollection.exists(hasStringType) + case _ => false + } + } + + val funInfos = spark.sessionState.functionRegistry.listFunction().map { funcId => + spark.sessionState.catalog.lookupFunctionInfo(funcId) + }.filter(funInfo => { + // make sure that there is a constructor. + val cl = Utils.classForName(funInfo.getClassName) + !cl.getConstructors.isEmpty + }).filter(funInfo => { + val className = funInfo.getClassName + val cl = Utils.classForName(funInfo.getClassName) + // dummy instance + // Take first constructor. + val headConstructor = cl.getConstructors.head + + val params = headConstructor.getParameters.map(p => p.getType) + val allExpressions = params.forall(p => p.isAssignableFrom(classOf[Expression]) || + p.isAssignableFrom(classOf[Seq[Expression]]) || + p.isAssignableFrom(classOf[Option[Expression]])) + + if (!allExpressions) { + false + } else { + val args = params.map { + case e if e.isAssignableFrom(classOf[Expression]) => Literal.create("1") + case se if se.isAssignableFrom(classOf[Seq[Expression]]) => + Seq(Literal.create("1"), Literal.create("2")) + case oe if oe.isAssignableFrom(classOf[Option[Expression]]) => None + } + // Find all expressions that have string as input + try { + val expr = headConstructor.newInstance(args: _*) + expr match { + case types: ExpectsInputTypes => + val inputTypes = types.inputTypes + // check if this is a collection... 
+ inputTypes.exists(hasStringType) + } + } catch { + case _: Throwable => false + } + } + }).toArray + + // Helper methods for generating data. + sealed trait CollationType + case object Utf8Binary extends CollationType + case object Utf8BinaryLcase extends CollationType + + def generateSingleEntry( + inputType: AbstractDataType, + collationType: CollationType): Expression = + inputType match { + // Try to make this a bit more random. + case AnyTimestampType => Literal("2009-07-30 12:58:59") + case BinaryType => Literal(new Array[Byte](5)) + case BooleanType => Literal(true) + case _: DatetimeType => Literal(1L) + case _: DecimalType => Literal(new Decimal) + case IntegerType | NumericType => Literal(1) + case LongType => Literal(1L) + case _: StringType | StringTypeAnyCollation | StringTypeBinaryLcase | AnyDataType => + collationType match { + case Utf8Binary => + Literal.create("dummy string", StringType("UTF8_BINARY")) + case Utf8BinaryLcase => + Literal.create("DuMmY sTrInG", StringType("UTF8_BINARY_LCASE")) + } + case TypeCollection(typeCollection) => + val strTypes = typeCollection.filter(hasStringType) + if (strTypes.isEmpty) { + // Take first type + generateSingleEntry(typeCollection.head, collationType) + } else { + // Take first string type + generateSingleEntry(strTypes.head, collationType) + } + case AbstractArrayType(elementType) => + generateSingleEntry(elementType, collationType).map( + lit => Literal.create(Seq(lit.asInstanceOf[Literal].value), ArrayType(lit.dataType)) + ).head + case ArrayType(elementType, _) => + generateSingleEntry(elementType, collationType).map( + lit => Literal.create(Seq(lit.asInstanceOf[Literal].value), ArrayType(lit.dataType)) + ).head + case ArrayType => + generateSingleEntry(StringTypeAnyCollation, collationType).map( + lit => Literal.create(Seq(lit.asInstanceOf[Literal].value), ArrayType(lit.dataType)) + ).head + } + + def generateData( + inputTypes: Seq[AbstractDataType], + collationType: CollationType): Seq[Expression] 
= { + inputTypes.map(generateSingleEntry(_, collationType)) + } + + val toSkip = List( + "get_json_object", + "map_zip_with", + "printf", + "transform_keys", + "concat_ws", + "format_string", + "session_window", + "transform_values", + "arrays_zip", + "hex" // this is fine + ) + + for (f <- funInfos.filter(f => !toSkip.contains(f.getName))) { + val cl = Utils.classForName(f.getClassName) + val headConstructor = cl.getConstructors.head + val params = headConstructor.getParameters.map(p => p.getType) + val paramCount = params.length + val args = params.map { + case e if e.isAssignableFrom(classOf[Expression]) => Literal.create("1") + case se if se.isAssignableFrom(classOf[Seq[Expression]]) => + Seq(Literal.create("1")) + case oe if oe.isAssignableFrom(classOf[Option[Expression]]) => None + } + val expr = headConstructor.newInstance(args: _*).asInstanceOf[ExpectsInputTypes] + val inputTypes = expr.inputTypes + + var nones = Array.fill(0)(None) + if (paramCount > inputTypes.length) { + nones = Array.fill(paramCount - inputTypes.length)(None) + } + + val inputDataUtf8Binary = generateData(inputTypes.take(paramCount), Utf8Binary) ++ nones + val instanceUtf8Binary = + headConstructor.newInstance(inputDataUtf8Binary: _*).asInstanceOf[Expression] + + val inputDataLcase = generateData(inputTypes.take(paramCount), Utf8BinaryLcase) ++ nones + val instanceLcase = headConstructor.newInstance(inputDataLcase: _*).asInstanceOf[Expression] + + val exceptionUtfBinary = { + try { + instanceUtf8Binary.eval(EmptyRow) + None + } catch { + case e: Throwable => Some(e) + } + } + + val exceptionLcase = { + try { + instanceLcase.eval(EmptyRow) + None + } catch { + case e: Throwable => Some(e) + } + } + + // Check that both cases either throw or pass + assert(exceptionUtfBinary.isDefined == exceptionLcase.isDefined) + + if (exceptionUtfBinary.isEmpty) { + val resUtf8Binary = instanceUtf8Binary.eval(EmptyRow) + val resUtf8Lcase = instanceLcase.eval(EmptyRow) + + val dt = instanceLcase.dataType 
+ + dt match { + case _: StringType if resUtf8Lcase != null && resUtf8Lcase != null => + assert(resUtf8Binary.isInstanceOf[UTF8String]) + assert(resUtf8Lcase.isInstanceOf[UTF8String]) + // scalastyle:off caselocale + assert(resUtf8Binary.asInstanceOf[UTF8String].toLowerCase.binaryEquals( + resUtf8Lcase.asInstanceOf[UTF8String].toLowerCase)) + // scalastyle:on caselocale + case _ => resUtf8Lcase === resUtf8Binary + } + } + else { + assert(exceptionUtfBinary.get.getClass == exceptionLcase.get.getClass) + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala index 7e7e5ca62b1f8..f9c375b173065 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala @@ -32,10 +32,7 @@ import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec} import org.apache.spark.sql.execution.joins._ import org.apache.spark.sql.internal.{SqlApiConf, SQLConf} -import org.apache.spark.sql.internal.types.{AbstractArrayType, StringTypeAnyCollation, StringTypeBinaryLcase} import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String -import org.apache.spark.util.Utils class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper with ExpressionEvalHelper { @@ -952,210 +949,6 @@ class CollationSuite extends DatasourceV2SQLBase } } - test("SPARK-48280: Expression Walker for Testing") { - // This test does following: - // 1) Take all expressions - // 2) Filter out ones that have at least one argument of StringType - // 3) Use reflection to create an instance of the expression using first constructor - // (test other as well). 
- // 4) Check if the expression is of type ExpectsInputTypes (should make this a bit broader) - // 5) Run eval against literals with strings under: - // a) UTF8_BINARY, "dummy string" as input. - // b) UTF8_BINARY_LCASE, "DuMmY sTrInG" as input. - // 6) Check if both expressions throw an exception. - // 7) If no exception, check if the result is the same. - // 8) There is a list of allowed expressions that can differ (e.g. hex) - def hasStringType(inputType: AbstractDataType): Boolean = { - inputType match { - case _: StringType | StringTypeAnyCollation | StringTypeBinaryLcase | AnyDataType => - true - case ArrayType => true - case ArrayType(elementType, _) => hasStringType(elementType) - case AbstractArrayType(elementType) => hasStringType(elementType) - case TypeCollection(typeCollection) => - typeCollection.exists(hasStringType) - case _ => false - } - } - - val funInfos = spark.sessionState.functionRegistry.listFunction().map { funcId => - spark.sessionState.catalog.lookupFunctionInfo(funcId) - }.filter(funInfo => { - // make sure that there is a constructor. - val cl = Utils.classForName(funInfo.getClassName) - !cl.getConstructors.isEmpty - }).filter(funInfo => { - val className = funInfo.getClassName - val cl = Utils.classForName(funInfo.getClassName) - // dummy instance - // Take first constructor. 
- val headConstructor = cl.getConstructors.head - - val params = headConstructor.getParameters.map(p => p.getType) - val allExpressions = params.forall(p => p.isAssignableFrom(classOf[Expression]) || - p.isAssignableFrom(classOf[Seq[Expression]]) || - p.isAssignableFrom(classOf[Option[Expression]])) - - if (!allExpressions) { - false - } else { - val args = params.map { - case e if e.isAssignableFrom(classOf[Expression]) => Literal.create("1") - case se if se.isAssignableFrom(classOf[Seq[Expression]]) => - Seq(Literal.create("1"), Literal.create("2")) - case oe if oe.isAssignableFrom(classOf[Option[Expression]]) => None - } - // Find all expressions that have string as input - try { - val expr = headConstructor.newInstance(args: _*) - expr match { - case types: ExpectsInputTypes => - val inputTypes = types.inputTypes - // check if this is a collection... - inputTypes.exists(hasStringType) - } - } catch { - case _: Throwable => false - } - } - }).toArray - - // Helper methods for generating data. - sealed trait CollationType - case object Utf8Binary extends CollationType - case object Utf8BinaryLcase extends CollationType - - def generateSingleEntry( - inputType: AbstractDataType, - collationType: CollationType): Expression = - inputType match { - // Try to make this a bit more random. 
- case AnyTimestampType => Literal("2009-07-30 12:58:59") - case BinaryType => Literal(new Array[Byte](5)) - case BooleanType => Literal(true) - case _: DatetimeType => Literal(1L) - case _: DecimalType => Literal(new Decimal) - case IntegerType | NumericType => Literal(1) - case LongType => Literal(1L) - case _: StringType | StringTypeAnyCollation | StringTypeBinaryLcase | AnyDataType => - collationType match { - case Utf8Binary => - Literal.create("dummy string", StringType("UTF8_BINARY")) - case Utf8BinaryLcase => - Literal.create("DuMmY sTrInG", StringType("UTF8_BINARY_LCASE")) - } - case TypeCollection(typeCollection) => - val strTypes = typeCollection.filter(hasStringType) - if (strTypes.isEmpty) { - // Take first type - generateSingleEntry(typeCollection.head, collationType) - } else { - // Take first string type - generateSingleEntry(strTypes.head, collationType) - } - case AbstractArrayType(elementType) => - generateSingleEntry(elementType, collationType).map( - lit => Literal.create(Seq(lit.asInstanceOf[Literal].value), ArrayType(lit.dataType)) - ).head - case ArrayType(elementType, _) => - generateSingleEntry(elementType, collationType).map( - lit => Literal.create(Seq(lit.asInstanceOf[Literal].value), ArrayType(lit.dataType)) - ).head - case ArrayType => - generateSingleEntry(StringTypeAnyCollation, collationType).map( - lit => Literal.create(Seq(lit.asInstanceOf[Literal].value), ArrayType(lit.dataType)) - ).head - } - - def generateData( - inputTypes: Seq[AbstractDataType], - collationType: CollationType): Seq[Expression] = { - inputTypes.map(generateSingleEntry(_, collationType)) - } - - val toSkip = List( - "get_json_object", - "map_zip_with", - "printf", - "transform_keys", - "concat_ws", - "format_string", - "session_window", - "transform_values", - "arrays_zip", - "hex" // this is fine - ) - - for (f <- funInfos.filter(f => !toSkip.contains(f.getName))) { - val cl = Utils.classForName(f.getClassName) - val headConstructor = cl.getConstructors.head 
- val params = headConstructor.getParameters.map(p => p.getType) - val paramCount = params.length - val args = params.map { - case e if e.isAssignableFrom(classOf[Expression]) => Literal.create("1") - case se if se.isAssignableFrom(classOf[Seq[Expression]]) => - Seq(Literal.create("1")) - case oe if oe.isAssignableFrom(classOf[Option[Expression]]) => None - } - val expr = headConstructor.newInstance(args: _*).asInstanceOf[ExpectsInputTypes] - val inputTypes = expr.inputTypes - - var nones = Array.fill(0)(None) - if (paramCount > inputTypes.length) { - nones = Array.fill(paramCount - inputTypes.length)(None) - } - - val inputDataUtf8Binary = generateData(inputTypes.take(paramCount), Utf8Binary) ++ nones - val instanceUtf8Binary = - headConstructor.newInstance(inputDataUtf8Binary: _*).asInstanceOf[Expression] - - val inputDataLcase = generateData(inputTypes.take(paramCount), Utf8BinaryLcase) ++ nones - val instanceLcase = headConstructor.newInstance(inputDataLcase: _*).asInstanceOf[Expression] - - val exceptionUtfBinary = { - try { - instanceUtf8Binary.eval(EmptyRow) - None - } catch { - case e: Throwable => Some(e) - } - } - - val exceptionLcase = { - try { - instanceLcase.eval(EmptyRow) - None - } catch { - case e: Throwable => Some(e) - } - } - - // Check that both cases either throw or pass - assert(exceptionUtfBinary.isDefined == exceptionLcase.isDefined) - - if (exceptionUtfBinary.isEmpty) { - val resUtf8Binary = instanceUtf8Binary.eval(EmptyRow) - val resUtf8Lcase = instanceLcase.eval(EmptyRow) - - val dt = instanceLcase.dataType - - dt match { - case _: StringType if resUtf8Lcase != null && resUtf8Lcase != null => - assert(resUtf8Binary.isInstanceOf[UTF8String]) - assert(resUtf8Lcase.isInstanceOf[UTF8String]) - // scalastyle:off caselocale - assert(resUtf8Binary.asInstanceOf[UTF8String].toLowerCase.binaryEquals( - resUtf8Lcase.asInstanceOf[UTF8String].toLowerCase)) - // scalastyle:on caselocale - case _ => resUtf8Lcase === resUtf8Binary - } - } - else { - 
assert(exceptionUtfBinary.get.getClass == exceptionLcase.get.getClass) - } - } - } - test("Support operations on complex types containing collated strings") { checkAnswer(sql("select reverse('abc' collate utf8_binary_lcase)"), Seq(Row("cba"))) checkAnswer(sql( From 2f3fc4c7715fbc4e47fa6ae62511808bff18be22 Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Wed, 5 Jun 2024 09:32:00 +0200 Subject: [PATCH 12/30] Revert changes in CollationSuite --- .../src/test/scala/org/apache/spark/sql/CollationSuite.scala | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala index f9c375b173065..6576847bcc091 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala @@ -32,10 +32,9 @@ import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec} import org.apache.spark.sql.execution.joins._ import org.apache.spark.sql.internal.{SqlApiConf, SQLConf} -import org.apache.spark.sql.types._ +import org.apache.spark.sql.types.{MapType, StringType, StructField, StructType} -class CollationSuite extends DatasourceV2SQLBase - with AdaptiveSparkPlanHelper with ExpressionEvalHelper { +class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { protected val v2Source = classOf[FakeV2ProviderWithCustomSchema].getName private val collationPreservingSources = Seq("parquet") From e4ea17df5d05745d423502a50318161918e85c0a Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Wed, 5 Jun 2024 15:13:21 +0200 Subject: [PATCH 13/30] Refactor code --- .../sql/CollationExpressionWalkerSuite.scala | 278 ++++++++++-------- 1 file changed, 162 insertions(+), 116 deletions(-) diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala index 8c93683ead1c2..2c2d98be61e1f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala @@ -18,18 +18,148 @@ package org.apache.spark.sql import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.expressions.{EmptyRow, ExpectsInputTypes, Expression, Literal} +import org.apache.spark.sql.catalyst.expressions.{EmptyRow, EvalMode, ExpectsInputTypes, Expression, Literal} import org.apache.spark.sql.internal.types.{AbstractArrayType, StringTypeAnyCollation, StringTypeBinaryLcase} import org.apache.spark.sql.test.SharedSparkSession -import org.apache.spark.sql.types.{AbstractDataType, AnyDataType, AnyTimestampType, ArrayType, BinaryType, BooleanType, DatetimeType, Decimal, DecimalType, IntegerType, LongType, NumericType, StringType, TypeCollection} +import org.apache.spark.sql.types.{AbstractDataType, AnyDataType, AnyTimestampType, ArrayType, BinaryType, BooleanType, DataType, DatetimeType, Decimal, DecimalType, IntegerType, LongType, NumericType, StringType, StructType, TypeCollection} import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils +/** + * This suite is introduced in order to test a bulk of expressions and functionalities related to + * collations + */ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSession { - test("SPARK-48280: Expression Walker for Testing") { + + // Trait to distinguish different cases for generation + sealed trait CollationType + + case object Utf8Binary extends CollationType + + case object Utf8BinaryLcase extends CollationType + + /** + * Helper function to generate all necesary parameters + * + * @param inputEntry - List of all input entries that need to be generated + * 
@param collationType - Flag defining collation type to use + * @return + */ + def generateData( + inputEntry: Seq[Any], + collationType: CollationType): Seq[Any] = { + inputEntry.map(generateSingleEntry(_, collationType)) + } + + /** + * Helper function to generate single entry of data. + * @param inputEntry - Single input entry that requires generation + * @param collationType - Flag defining collation type to use + * @return + */ + def generateSingleEntry( + inputEntry: Any, + collationType: CollationType): Any = + inputEntry match { + case e: Class[_] if e.isAssignableFrom(classOf[Expression]) => Literal.create("1") + case se: Class[_] if se.isAssignableFrom(classOf[Seq[Expression]]) => + Seq(Literal.create("1"), Literal.create("2")) + case oe: Class[_] if oe.isAssignableFrom(classOf[Option[Expression]]) => None + case b: Class[_] if b.isAssignableFrom(classOf[Boolean]) => false + case dt: Class[_] if dt.isAssignableFrom(classOf[DataType]) => StringType + case st: Class[_] if st.isAssignableFrom(classOf[StructType]) => StructType + case em: Class[_] if em.isAssignableFrom(classOf[EvalMode.Value]) => EvalMode.LEGACY + case m: Class[_] if m.isAssignableFrom(classOf[Map[_, _]]) => Map.empty + case c: Class[_] if c.isAssignableFrom(classOf[Char]) => '\\' + case i: Class[_] if i.isAssignableFrom(classOf[Int]) => 0 + case l: Class[_] if l.isAssignableFrom(classOf[Long]) => 0 + case adt: AbstractDataType => generateLiterals(adt, collationType) + case Nil => Seq() + case (head: AbstractDataType) :: rest => generateData(head :: rest, collationType) + } + + /** + * Helper function to generate single literal from the given type. + * + * @param inputType - Single input literal type that requires generation + * @param collationType - Flag defining collation type to use + * @return + */ + def generateLiterals( + inputType: AbstractDataType, + collationType: CollationType): Expression = + inputType match { + // TODO: Try to make this a bit more random. 
+ case AnyTimestampType => Literal("2009-07-30 12:58:59") + case BinaryType => Literal(new Array[Byte](5)) + case BooleanType => Literal(true) + case _: DatetimeType => Literal(1L) + case _: DecimalType => Literal(new Decimal) + case IntegerType | NumericType => Literal(1) + case LongType => Literal(1L) + case _: StringType | StringTypeAnyCollation | StringTypeBinaryLcase | AnyDataType => + collationType match { + case Utf8Binary => + Literal.create("dummy string", StringType("UTF8_BINARY")) + case Utf8BinaryLcase => + Literal.create("DuMmY sTrInG", StringType("UTF8_BINARY_LCASE")) + } + case TypeCollection(typeCollection) => + val strTypes = typeCollection.filter(hasStringType) + if (strTypes.isEmpty) { + // Take first type + generateLiterals(typeCollection.head, collationType) + } else { + // Take first string type + generateLiterals(strTypes.head, collationType) + } + case AbstractArrayType(elementType) => + generateLiterals(elementType, collationType).map( + lit => Literal.create(Seq(lit.asInstanceOf[Literal].value), ArrayType(lit.dataType)) + ).head + case ArrayType(elementType, _) => + generateLiterals(elementType, collationType).map( + lit => Literal.create(Seq(lit.asInstanceOf[Literal].value), ArrayType(lit.dataType)) + ).head + case ArrayType => + generateLiterals(StringTypeAnyCollation, collationType).map( + lit => Literal.create(Seq(lit.asInstanceOf[Literal].value), ArrayType(lit.dataType)) + ).head + } + + /** + * Helper function to extract types of relevance + * @param inputType + * @return + */ + def hasStringType(inputType: AbstractDataType): Boolean = { + inputType match { + case _: StringType | StringTypeAnyCollation | StringTypeBinaryLcase | AnyDataType => + true + case ArrayType => true + case ArrayType(elementType, _) => hasStringType(elementType) + case AbstractArrayType(elementType) => hasStringType(elementType) + case TypeCollection(typeCollection) => + typeCollection.exists(hasStringType) + case _ => false + } + } + + def 
replaceExpressions(inputTypes: Seq[AbstractDataType], params: Seq[Class[_]]): Seq[Any] = { + (inputTypes, params) match { + case (Nil, mparams) => mparams + case (_, Nil) => Nil + case (minputTypes, mparams) if mparams.head.isAssignableFrom(classOf[Expression]) => + minputTypes.head +: replaceExpressions(inputTypes.tail, mparams.tail) + case (minputTypes, mparams) => + mparams.head +: replaceExpressions(minputTypes.tail, mparams.tail) + } + } + + test("SPARK-48280: Expression Walker for Test") { // This test does following: // 1) Take all expressions - // 2) Filter out ones that have at least one argument of StringType + // 2) Find the ones that have at least one argument of StringType // 3) Use reflection to create an instance of the expression using first constructor // (test other as well). // 4) Check if the expression is of type ExpectsInputTypes (should make this a bit broader) @@ -39,19 +169,8 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi // 6) Check if both expressions throw an exception. // 7) If no exception, check if the result is the same. // 8) There is a list of allowed expressions that can differ (e.g. 
hex) - def hasStringType(inputType: AbstractDataType): Boolean = { - inputType match { - case _: StringType | StringTypeAnyCollation | StringTypeBinaryLcase | AnyDataType => - true - case ArrayType => true - case ArrayType(elementType, _) => hasStringType(elementType) - case AbstractArrayType(elementType) => hasStringType(elementType) - case TypeCollection(typeCollection) => - typeCollection.exists(hasStringType) - case _ => false - } - } - + var expressionCounter = 0 + var expectsExpressionCounter = 0; val funInfos = spark.sessionState.functionRegistry.listFunction().map { funcId => spark.sessionState.catalog.lookupFunctionInfo(funcId) }.filter(funInfo => { @@ -59,131 +178,58 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi val cl = Utils.classForName(funInfo.getClassName) !cl.getConstructors.isEmpty }).filter(funInfo => { - val className = funInfo.getClassName + expressionCounter = expressionCounter + 1 val cl = Utils.classForName(funInfo.getClassName) // dummy instance // Take first constructor. val headConstructor = cl.getConstructors.head val params = headConstructor.getParameters.map(p => p.getType) - val allExpressions = params.forall(p => p.isAssignableFrom(classOf[Expression]) || - p.isAssignableFrom(classOf[Seq[Expression]]) || - p.isAssignableFrom(classOf[Option[Expression]])) - - if (!allExpressions) { - false - } else { - val args = params.map { - case e if e.isAssignableFrom(classOf[Expression]) => Literal.create("1") - case se if se.isAssignableFrom(classOf[Seq[Expression]]) => - Seq(Literal.create("1"), Literal.create("2")) - case oe if oe.isAssignableFrom(classOf[Option[Expression]]) => None - } - // Find all expressions that have string as input - try { - val expr = headConstructor.newInstance(args: _*) - expr match { - case types: ExpectsInputTypes => - val inputTypes = types.inputTypes - // check if this is a collection... 
- inputTypes.exists(hasStringType) - } - } catch { - case _: Throwable => false + + val args = generateData(params.toSeq, Utf8Binary) + // Find all expressions that have string as input + try { + val expr = headConstructor.newInstance(args: _*) + expr match { + case types: ExpectsInputTypes => + expectsExpressionCounter = expectsExpressionCounter + 1 + val inputTypes = types.inputTypes + inputTypes.exists(hasStringType) } + } catch { + case _: Throwable => false } }).toArray - // Helper methods for generating data. - sealed trait CollationType - case object Utf8Binary extends CollationType - case object Utf8BinaryLcase extends CollationType - - def generateSingleEntry( - inputType: AbstractDataType, - collationType: CollationType): Expression = - inputType match { - // Try to make this a bit more random. - case AnyTimestampType => Literal("2009-07-30 12:58:59") - case BinaryType => Literal(new Array[Byte](5)) - case BooleanType => Literal(true) - case _: DatetimeType => Literal(1L) - case _: DecimalType => Literal(new Decimal) - case IntegerType | NumericType => Literal(1) - case LongType => Literal(1L) - case _: StringType | StringTypeAnyCollation | StringTypeBinaryLcase | AnyDataType => - collationType match { - case Utf8Binary => - Literal.create("dummy string", StringType("UTF8_BINARY")) - case Utf8BinaryLcase => - Literal.create("DuMmY sTrInG", StringType("UTF8_BINARY_LCASE")) - } - case TypeCollection(typeCollection) => - val strTypes = typeCollection.filter(hasStringType) - if (strTypes.isEmpty) { - // Take first type - generateSingleEntry(typeCollection.head, collationType) - } else { - // Take first string type - generateSingleEntry(strTypes.head, collationType) - } - case AbstractArrayType(elementType) => - generateSingleEntry(elementType, collationType).map( - lit => Literal.create(Seq(lit.asInstanceOf[Literal].value), ArrayType(lit.dataType)) - ).head - case ArrayType(elementType, _) => - generateSingleEntry(elementType, collationType).map( - lit => 
Literal.create(Seq(lit.asInstanceOf[Literal].value), ArrayType(lit.dataType)) - ).head - case ArrayType => - generateSingleEntry(StringTypeAnyCollation, collationType).map( - lit => Literal.create(Seq(lit.asInstanceOf[Literal].value), ArrayType(lit.dataType)) - ).head - } - - def generateData( - inputTypes: Seq[AbstractDataType], - collationType: CollationType): Seq[Expression] = { - inputTypes.map(generateSingleEntry(_, collationType)) - } - val toSkip = List( - "get_json_object", "map_zip_with", - "printf", "transform_keys", - "concat_ws", - "format_string", - "session_window", + "session_window", // has too complex inputType "transform_values", - "arrays_zip", + "reduce", + "parse_url", // Parse URL is using wrong concepts "hex" // this is fine ) - + // scalastyle:off println + println("Total number of expression: " + expressionCounter) + println("Total number of expression that expect input: " + expectsExpressionCounter) + println("Number of extracted expressions of relevance: " + (funInfos.length - toSkip.length)) + // scalastyle:on println for (f <- funInfos.filter(f => !toSkip.contains(f.getName))) { val cl = Utils.classForName(f.getClassName) val headConstructor = cl.getConstructors.head val params = headConstructor.getParameters.map(p => p.getType) - val paramCount = params.length - val args = params.map { - case e if e.isAssignableFrom(classOf[Expression]) => Literal.create("1") - case se if se.isAssignableFrom(classOf[Seq[Expression]]) => - Seq(Literal.create("1")) - case oe if oe.isAssignableFrom(classOf[Option[Expression]]) => None - } + val args = generateData(params.toSeq, Utf8Binary) val expr = headConstructor.newInstance(args: _*).asInstanceOf[ExpectsInputTypes] val inputTypes = expr.inputTypes - var nones = Array.fill(0)(None) - if (paramCount > inputTypes.length) { - nones = Array.fill(paramCount - inputTypes.length)(None) - } - - val inputDataUtf8Binary = generateData(inputTypes.take(paramCount), Utf8Binary) ++ nones + val inputDataUtf8Binary = 
+ generateData(replaceExpressions(inputTypes, params.toSeq), Utf8Binary) val instanceUtf8Binary = headConstructor.newInstance(inputDataUtf8Binary: _*).asInstanceOf[Expression] - val inputDataLcase = generateData(inputTypes.take(paramCount), Utf8BinaryLcase) ++ nones + val inputDataLcase = + generateData(replaceExpressions(inputTypes, params.toSeq), Utf8BinaryLcase) val instanceLcase = headConstructor.newInstance(inputDataLcase: _*).asInstanceOf[Expression] val exceptionUtfBinary = { From 263c1416253195902bc22fb8a5932ff1692c49a3 Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Thu, 6 Jun 2024 08:54:36 +0200 Subject: [PATCH 14/30] Add MapType support --- .../sql/CollationExpressionWalkerSuite.scala | 20 +++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala index 2c2d98be61e1f..1c29b64120063 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala @@ -21,7 +21,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.{EmptyRow, EvalMode, ExpectsInputTypes, Expression, Literal} import org.apache.spark.sql.internal.types.{AbstractArrayType, StringTypeAnyCollation, StringTypeBinaryLcase} import org.apache.spark.sql.test.SharedSparkSession -import org.apache.spark.sql.types.{AbstractDataType, AnyDataType, AnyTimestampType, ArrayType, BinaryType, BooleanType, DataType, DatetimeType, Decimal, DecimalType, IntegerType, LongType, NumericType, StringType, StructType, TypeCollection} +import org.apache.spark.sql.types.{AbstractDataType, AnyDataType, AnyTimestampType, ArrayType, BinaryType, BooleanType, DataType, DatetimeType, Decimal, DecimalType, IntegerType, LongType, MapType, NumericType, StringType, StructType, 
TypeCollection} import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils @@ -125,6 +125,14 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi generateLiterals(StringTypeAnyCollation, collationType).map( lit => Literal.create(Seq(lit.asInstanceOf[Literal].value), ArrayType(lit.dataType)) ).head + case MapType => + val key = generateLiterals(StringTypeAnyCollation, collationType) + val value = generateLiterals(StringTypeAnyCollation, collationType) + Literal.create(Map(key -> value)) + case MapType(keyType, valueType, _) => + val key = generateLiterals(keyType, collationType) + val value = generateLiterals(valueType, collationType) + Literal.create(Map(key -> value)) } /** @@ -137,6 +145,8 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi case _: StringType | StringTypeAnyCollation | StringTypeBinaryLcase | AnyDataType => true case ArrayType => true + case MapType => true + case MapType(keyType, valueType, _) => hasStringType(keyType) || hasStringType(valueType) case ArrayType(elementType, _) => hasStringType(elementType) case AbstractArrayType(elementType) => hasStringType(elementType) case TypeCollection(typeCollection) => @@ -202,12 +212,8 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi }).toArray val toSkip = List( - "map_zip_with", - "transform_keys", "session_window", // has too complex inputType - "transform_values", - "reduce", - "parse_url", // Parse URL is using wrong concepts + "parse_url", // Parse URL is using wrong concepts, not related to ExpectsInputTypes "hex" // this is fine ) // scalastyle:off println @@ -216,6 +222,8 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi println("Number of extracted expressions of relevance: " + (funInfos.length - toSkip.length)) // scalastyle:on println for (f <- funInfos.filter(f => !toSkip.contains(f.getName))) { + // scalastyle:off println + 
println(f.getName) val cl = Utils.classForName(f.getClassName) val headConstructor = cl.getConstructors.head val params = headConstructor.getParameters.map(p => p.getType) From 29bb4008924ba601c8c6a2d07a4b638226bde692 Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Thu, 6 Jun 2024 09:18:49 +0200 Subject: [PATCH 15/30] Add support for StructType --- .../spark/sql/CollationExpressionWalkerSuite.scala | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala index 1c29b64120063..89cbd4dd45f5d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.expressions.{EmptyRow, EvalMode, ExpectsInputTypes, Expression, Literal} +import org.apache.spark.sql.catalyst.expressions.{EmptyRow, EvalMode, ExpectsInputTypes, Expression, GenericInternalRow, Literal} import org.apache.spark.sql.internal.types.{AbstractArrayType, StringTypeAnyCollation, StringTypeBinaryLcase} import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.{AbstractDataType, AnyDataType, AnyTimestampType, ArrayType, BinaryType, BooleanType, DataType, DatetimeType, Decimal, DecimalType, IntegerType, LongType, MapType, NumericType, StringType, StructType, TypeCollection} @@ -133,6 +133,13 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi val key = generateLiterals(keyType, collationType) val value = generateLiterals(valueType, collationType) Literal.create(Map(key -> value)) + case StructType => + Literal.create((generateLiterals(StringTypeAnyCollation, collationType), + generateLiterals(StringTypeAnyCollation, 
collationType))) + case StructType(fields) => + Literal.create(new GenericInternalRow( + fields.map(f => generateLiterals(f.dataType, collationType).asInstanceOf[Any])), + StructType(fields)) } /** @@ -151,6 +158,8 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi case AbstractArrayType(elementType) => hasStringType(elementType) case TypeCollection(typeCollection) => typeCollection.exists(hasStringType) + case StructType => true + case StructType(fields) => fields.exists(sf => hasStringType(sf.dataType)) case _ => false } } @@ -212,7 +221,6 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi }).toArray val toSkip = List( - "session_window", // has too complex inputType "parse_url", // Parse URL is using wrong concepts, not related to ExpectsInputTypes "hex" // this is fine ) From 55f84daa8d5cc411af8cebefcb099126a29862af Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Thu, 6 Jun 2024 09:54:34 +0200 Subject: [PATCH 16/30] Remove unnecessary prints --- .../spark/sql/CollationExpressionWalkerSuite.scala | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala index 89cbd4dd45f5d..da0350a812e4a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala @@ -61,10 +61,12 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi inputEntry: Any, collationType: CollationType): Any = inputEntry match { - case e: Class[_] if e.isAssignableFrom(classOf[Expression]) => Literal.create("1") + case e: Class[_] if e.isAssignableFrom(classOf[Expression]) => + generateLiterals(StringTypeAnyCollation, collationType) case se: Class[_] if se.isAssignableFrom(classOf[Seq[Expression]]) 
=> - Seq(Literal.create("1"), Literal.create("2")) - case oe: Class[_] if oe.isAssignableFrom(classOf[Option[Expression]]) => None + Seq(generateLiterals(StringTypeAnyCollation, collationType), + generateLiterals(StringTypeAnyCollation, collationType)) + case oe: Class[_] if oe.isAssignableFrom(classOf[Option[Any]]) => None case b: Class[_] if b.isAssignableFrom(classOf[Boolean]) => false case dt: Class[_] if dt.isAssignableFrom(classOf[DataType]) => StringType case st: Class[_] if st.isAssignableFrom(classOf[StructType]) => StructType @@ -213,7 +215,7 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi case types: ExpectsInputTypes => expectsExpressionCounter = expectsExpressionCounter + 1 val inputTypes = types.inputTypes - inputTypes.exists(hasStringType) + inputTypes.exists(hasStringType) || inputTypes.isEmpty } } catch { case _: Throwable => false @@ -230,8 +232,6 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi println("Number of extracted expressions of relevance: " + (funInfos.length - toSkip.length)) // scalastyle:on println for (f <- funInfos.filter(f => !toSkip.contains(f.getName))) { - // scalastyle:off println - println(f.getName) val cl = Utils.classForName(f.getClassName) val headConstructor = cl.getConstructors.head val params = headConstructor.getParameters.map(p => p.getType) From ba90ca5ca094f39c9c5bbce48ddec090c29ef23e Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Thu, 6 Jun 2024 10:17:29 +0200 Subject: [PATCH 17/30] Improve comment --- .../org/apache/spark/sql/CollationExpressionWalkerSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala index da0350a812e4a..850fa9328a7b3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala @@ -223,7 +223,7 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi }).toArray val toSkip = List( - "parse_url", // Parse URL is using wrong concepts, not related to ExpectsInputTypes + "parse_url", // Parse URL cannot be generalized with ExpectInputTypes "hex" // this is fine ) // scalastyle:off println From 7017d801ee2b30e3731e4ce4294b603e04b6afe8 Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Thu, 6 Jun 2024 10:20:30 +0200 Subject: [PATCH 18/30] Improve comment --- .../org/apache/spark/sql/CollationExpressionWalkerSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala index 850fa9328a7b3..02ab499683396 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala @@ -224,7 +224,7 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi val toSkip = List( "parse_url", // Parse URL cannot be generalized with ExpectInputTypes - "hex" // this is fine + "hex" // Different inputs affect conversion ) // scalastyle:off println println("Total number of expression: " + expressionCounter) From 497baa50512f6d41bc4ef07da15847f3b7d42b1a Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Mon, 10 Jun 2024 09:02:32 +0200 Subject: [PATCH 19/30] Add example walker --- .../sql/CollationExpressionWalkerSuite.scala | 108 ++++++++++++++---- 1 file changed, 88 insertions(+), 20 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala index 02ab499683396..7c06bc8f2d3f9 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala @@ -17,8 +17,9 @@ package org.apache.spark.sql -import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.expressions.{EmptyRow, EvalMode, ExpectsInputTypes, Expression, GenericInternalRow, Literal} +import org.apache.spark.{SparkFunSuite, SparkRuntimeException} +import org.apache.spark.sql.catalyst.expressions.{EmptyRow, EvalMode, ExpectsInputTypes, Expression, ExpressionInfo, GenericInternalRow, Literal} +import org.apache.spark.sql.internal.SqlApiConf import org.apache.spark.sql.internal.types.{AbstractArrayType, StringTypeAnyCollation, StringTypeBinaryLcase} import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.{AbstractDataType, AnyDataType, AnyTimestampType, ArrayType, BinaryType, BooleanType, DataType, DatetimeType, Decimal, DecimalType, IntegerType, LongType, MapType, NumericType, StringType, StructType, TypeCollection} @@ -177,19 +178,7 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi } } - test("SPARK-48280: Expression Walker for Test") { - // This test does following: - // 1) Take all expressions - // 2) Find the ones that have at least one argument of StringType - // 3) Use reflection to create an instance of the expression using first constructor - // (test other as well). - // 4) Check if the expression is of type ExpectsInputTypes (should make this a bit broader) - // 5) Run eval against literals with strings under: - // a) UTF8_BINARY, "dummy string" as input. - // b) UTF8_BINARY_LCASE, "DuMmY sTrInG" as input. - // 6) Check if both expressions throw an exception. - // 7) If no exception, check if the result is the same. - // 8) There is a list of allowed expressions that can differ (e.g. 
hex) + def extractRelevantExpressions(): (Array[ExpressionInfo], List[String]) = { var expressionCounter = 0 var expectsExpressionCounter = 0; val funInfos = spark.sessionState.functionRegistry.listFunction().map { funcId => @@ -226,11 +215,30 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi "parse_url", // Parse URL cannot be generalized with ExpectInputTypes "hex" // Different inputs affect conversion ) - // scalastyle:off println - println("Total number of expression: " + expressionCounter) - println("Total number of expression that expect input: " + expectsExpressionCounter) - println("Number of extracted expressions of relevance: " + (funInfos.length - toSkip.length)) - // scalastyle:on println + + logInfo("Total number of expression: " + expressionCounter) + logInfo("Total number of expression that expect input: " + expectsExpressionCounter) + logInfo("Number of extracted expressions of relevance: " + (funInfos.length - toSkip.length)) + + (funInfos, toSkip) + } + + test("SPARK-48280: Expression Evaluator Test") { + // This test does following: + // 1) Take all expressions + // 2) Find the ones that have at least one argument of StringType + // 3) Use reflection to create an instance of the expression using first constructor + // (test other as well). + // 4) Check if the expression is of type ExpectsInputTypes (should make this a bit broader) + // 5) Run eval against literals with strings under: + // a) UTF8_BINARY, "dummy string" as input. + // b) UTF8_BINARY_LCASE, "DuMmY sTrInG" as input. + // 6) Check if both expressions throw an exception. + // 7) If no exception, check if the result is the same. + // 8) There is a list of allowed expressions that can differ (e.g. 
hex) + + val (funInfos, toSkip) = extractRelevantExpressions() + for (f <- funInfos.filter(f => !toSkip.contains(f.getName))) { val cl = Utils.classForName(f.getClassName) val headConstructor = cl.getConstructors.head @@ -291,4 +299,64 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi } } } + + test("SPARK-48280: Expression Walker for SQL query examples") { + val funInfos = spark.sessionState.functionRegistry.listFunction().map { funcId => + spark.sessionState.catalog.lookupFunctionInfo(funcId) + } + + // If expression is expected to return different results, it needs to be skipped + val toSkip = List( + // need to skip as these give timestamp/time related output + "current_timestamp", + "unix_timestamp", + "localtimestamp", + "now", + // need to skip as plans differ in STRING <-> STRING COLLATE UTF8_BINARY_LCASE + "current_timezone", + "schema_of_variant", + // need to skip as result is expected to differ + "collation", + "contains", + "aes_encrypt", + "translate", + "replace", + "grouping", + "grouping_id", + // need to skip as these are random functions + "rand", + "random", + "randn", + "uuid", + "shuffle", + // other functions which are not yet supported + "date_sub", + "date_add", + "dateadd", + "window", + "window_time", + "session_window", + "reflect", + "try_reflect", + "levenshtein", + "java_method" + ) + + for (funInfo <- funInfos.filter(f => !toSkip.contains(f.getName))) { + println("checking - " + funInfo.getName) + for (m <- "> .*;".r.findAllIn(funInfo.getExamples)) { + try { + val resultUTF8 = sql(m.substring(2)) + withSQLConf(SqlApiConf.DEFAULT_COLLATION -> "UTF8_BINARY_LCASE") { + val resultUTF8Lcase = sql(m.substring(2)) + assert(resultUTF8.collect() === resultUTF8Lcase.collect()) + } + } + catch { + case e: SparkRuntimeException => assert(e.getErrorClass == "USER_RAISED_EXCEPTION") + case other: Throwable => other + } + } + } + } } From 4e7b611129cb12146f3dbeb2b0dae1f96f1369a0 Mon Sep 17 00:00:00 2001 From: 
Mihailo Milosevic Date: Tue, 11 Jun 2024 19:40:00 +0200 Subject: [PATCH 20/30] Add new test --- .../sql/CollationExpressionWalkerSuite.scala | 121 +++++++++++++++++- 1 file changed, 115 insertions(+), 6 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala index 7c06bc8f2d3f9..558421fa8c7e3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala @@ -18,9 +18,9 @@ package org.apache.spark.sql import org.apache.spark.{SparkFunSuite, SparkRuntimeException} -import org.apache.spark.sql.catalyst.expressions.{EmptyRow, EvalMode, ExpectsInputTypes, Expression, ExpressionInfo, GenericInternalRow, Literal} +import org.apache.spark.sql.catalyst.expressions.{CreateArray, EmptyRow, EvalMode, ExpectsInputTypes, Expression, ExpressionInfo, GenericInternalRow, Literal} import org.apache.spark.sql.internal.SqlApiConf -import org.apache.spark.sql.internal.types.{AbstractArrayType, StringTypeAnyCollation, StringTypeBinaryLcase} +import org.apache.spark.sql.internal.types.{AbstractArrayType, AbstractStringType, StringTypeAnyCollation, StringTypeBinaryLcase} import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.{AbstractDataType, AnyDataType, AnyTimestampType, ArrayType, BinaryType, BooleanType, DataType, DatetimeType, Decimal, DecimalType, IntegerType, LongType, MapType, NumericType, StringType, StructType, TypeCollection} import org.apache.spark.unsafe.types.UTF8String @@ -52,6 +52,12 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi inputEntry.map(generateSingleEntry(_, collationType)) } + def generateDataAsStrings( + inputEntry: Seq[AbstractDataType], + collationType: CollationType): Seq[Any] = { + inputEntry.map(generateInputAsString(_, collationType))) + 
} + /** * Helper function to generate single entry of data. * @param inputEntry - Single input entry that requires generation @@ -65,8 +71,8 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi case e: Class[_] if e.isAssignableFrom(classOf[Expression]) => generateLiterals(StringTypeAnyCollation, collationType) case se: Class[_] if se.isAssignableFrom(classOf[Seq[Expression]]) => - Seq(generateLiterals(StringTypeAnyCollation, collationType), - generateLiterals(StringTypeAnyCollation, collationType)) + CreateArray(Seq(generateLiterals(StringTypeAnyCollation, collationType), + generateLiterals(StringTypeAnyCollation, collationType))) case oe: Class[_] if oe.isAssignableFrom(classOf[Option[Any]]) => None case b: Class[_] if b.isAssignableFrom(classOf[Boolean]) => false case dt: Class[_] if dt.isAssignableFrom(classOf[DataType]) => StringType @@ -100,7 +106,7 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi case _: DecimalType => Literal(new Decimal) case IntegerType | NumericType => Literal(1) case LongType => Literal(1L) - case _: StringType | StringTypeAnyCollation | StringTypeBinaryLcase | AnyDataType => + case _: StringType | AnyDataType | _: AbstractStringType => collationType match { case Utf8Binary => Literal.create("dummy string", StringType("UTF8_BINARY")) @@ -145,6 +151,51 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi StructType(fields)) } + def generateInputAsString( + inputType: AbstractDataType, + collationType: CollationType): String = + inputType match { + // TODO: Try to make this a bit more random. 
+ case AnyTimestampType => "'2009-07-30 12:58:59'" + case BinaryType => "X'0'" + case BooleanType => "true" + case _: DatetimeType => "date'2016-04-08'" + case _: DecimalType => "0.0" + case IntegerType | NumericType => "1" + case LongType => Literal(1L) + case _: StringType | AnyDataType | _: AbstractStringType => + collationType match { + case Utf8Binary => "dummy string" + case Utf8BinaryLcase => "DuMmY sTrInG" + } + case TypeCollection(typeCollection) => + val strTypes = typeCollection.filter(hasStringType) + if (strTypes.isEmpty) { + // Take first type + generateInputAsString(typeCollection.head, collationType) + } else { + // Take first string type + generateInputAsString(strTypes.head, collationType) + } + case AbstractArrayType(elementType) => + "array(" + generateInputAsString(elementType, collationType) + ")" + case ArrayType(elementType, _) => + "array(" + generateInputAsString(elementType, collationType) + ")" + case ArrayType => + "array(" + generateInputAsString(StringTypeAnyCollation, collationType) + ")" + case MapType => + "map(" + generateInputAsString(StringTypeAnyCollation,collationType) + ", " + + generateInputAsString(StringTypeAnyCollation, collationType) + ")" + case MapType(keyType, valueType, _) => + "map(" + generateInputAsString(keyType, collationType) + ", " + + generateInputAsString(valueType, collationType) + ")" + case StructType => + "struct(" + generateInputAsString(StringTypeAnyCollation, collationType) + ", " + + generateInputAsString(StringTypeAnyCollation, collationType) +")" + case StructType(fields) => + "named_struct(" + ")" + fields.map(f => f.) 
+ } + /** * Helper function to extract types of relevance * @param inputType @@ -223,7 +274,7 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi (funInfos, toSkip) } - test("SPARK-48280: Expression Evaluator Test") { + test("SPARK-48280: Expression Walker for expression evaluation") { // This test does following: // 1) Take all expressions // 2) Find the ones that have at least one argument of StringType @@ -300,6 +351,64 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi } } + test("SPARK-48280: Expression Walker for plan generation of expressions") { + + var (funInfos, toSkip) = extractRelevantExpressions() + toSkip = + "element_at" :: + "try_element_at" :: + "reduce" :: + "aggregate" :: + "array_intersect" :: + toSkip + + var typeList = List( + "date", + "map,string>", + "string", + "array", + "float", + "smallint", + "map", + "map,struct<>>", + "struct", + "bigint", + "array>", + "timestamp", + "array>", + "array,value:struct<>>>", + "struct", + "timestamp_ntz", + "decimal(0,0)", + "double", + "int", + "map>", + "boolean", + "struct<>", + "binary" + ) + val dt = collection.mutable.Set[String]() + for (f <- funInfos.filter(f => !toSkip.contains(f.getName))) { + println("checking - " + f.getName) + val cl = Utils.classForName(f.getClassName) + val headConstructor = cl.getConstructors.head + val params = headConstructor.getParameters.map(p => p.getType) + val args = generateData(params.toSeq, Utf8Binary) + val expr = headConstructor.newInstance(args: _*).asInstanceOf[ExpectsInputTypes] + val inputTypes = expr.inputTypes + + val inputDataUtf8Binary = { + generateData(replaceExpressions(inputTypes, params.toSeq), Utf8BinaryLcase) + } + println(inputDataUtf8Binary.map(_.toString)) + val instanceUtf8Binary = { + headConstructor.newInstance(inputDataUtf8Binary: _*).asInstanceOf[Expression] + } + dt += instanceUtf8Binary.dataType.simpleString + } + println(dt) + } + test("SPARK-48280: Expression Walker 
for SQL query examples") { val funInfos = spark.sessionState.functionRegistry.listFunction().map { funcId => spark.sessionState.catalog.lookupFunctionInfo(funcId) From 8cdb7ada29a6810d1083e1b0c89171cbdcb09ad0 Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Thu, 13 Jun 2024 08:47:19 +0200 Subject: [PATCH 21/30] Add codeGen test --- .../sql/CollationExpressionWalkerSuite.scala | 231 +++++++++++++----- 1 file changed, 176 insertions(+), 55 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala index 558421fa8c7e3..284d175501e14 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql import org.apache.spark.{SparkFunSuite, SparkRuntimeException} -import org.apache.spark.sql.catalyst.expressions.{CreateArray, EmptyRow, EvalMode, ExpectsInputTypes, Expression, ExpressionInfo, GenericInternalRow, Literal} +import org.apache.spark.sql.catalyst.expressions.{BinaryComparison, CreateArray, EmptyRow, EvalMode, ExpectsInputTypes, Expression, ExpressionInfo, GenericInternalRow, Literal} import org.apache.spark.sql.internal.SqlApiConf import org.apache.spark.sql.internal.types.{AbstractArrayType, AbstractStringType, StringTypeAnyCollation, StringTypeBinaryLcase} import org.apache.spark.sql.test.SharedSparkSession @@ -55,7 +55,7 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi def generateDataAsStrings( inputEntry: Seq[AbstractDataType], collationType: CollationType): Seq[Any] = { - inputEntry.map(generateInputAsString(_, collationType))) + inputEntry.map(generateInputAsString(_, collationType)) } /** @@ -111,7 +111,7 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi case Utf8Binary => 
Literal.create("dummy string", StringType("UTF8_BINARY")) case Utf8BinaryLcase => - Literal.create("DuMmY sTrInG", StringType("UTF8_BINARY_LCASE")) + Literal.create("DuMmY sTrInG", StringType("UTF8_LCASE")) } case TypeCollection(typeCollection) => val strTypes = typeCollection.filter(hasStringType) @@ -156,17 +156,17 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi collationType: CollationType): String = inputType match { // TODO: Try to make this a bit more random. - case AnyTimestampType => "'2009-07-30 12:58:59'" + case AnyTimestampType => "TIMESTAMP'2009-07-30 12:58:59'" case BinaryType => "X'0'" - case BooleanType => "true" + case BooleanType => "True" case _: DatetimeType => "date'2016-04-08'" case _: DecimalType => "0.0" case IntegerType | NumericType => "1" - case LongType => Literal(1L) + case LongType => "1" case _: StringType | AnyDataType | _: AbstractStringType => collationType match { - case Utf8Binary => "dummy string" - case Utf8BinaryLcase => "DuMmY sTrInG" + case Utf8Binary => "'dummy string' COLLATE UTF8_BINARY" + case Utf8BinaryLcase => "'DuMmY sTrInG' COLLATE UTF8_LCASE" } case TypeCollection(typeCollection) => val strTypes = typeCollection.filter(hasStringType) @@ -184,16 +184,63 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi case ArrayType => "array(" + generateInputAsString(StringTypeAnyCollation, collationType) + ")" case MapType => - "map(" + generateInputAsString(StringTypeAnyCollation,collationType) + ", " + + "map(" + generateInputAsString(StringTypeAnyCollation, collationType) + ", " + generateInputAsString(StringTypeAnyCollation, collationType) + ")" case MapType(keyType, valueType, _) => "map(" + generateInputAsString(keyType, collationType) + ", " + generateInputAsString(valueType, collationType) + ")" case StructType => - "struct(" + generateInputAsString(StringTypeAnyCollation, collationType) + ", " + - generateInputAsString(StringTypeAnyCollation, 
collationType) +")" + "named_struct( 'start', " + generateInputAsString(StringTypeAnyCollation, collationType) + + ", 'end', " + generateInputAsString(StringTypeAnyCollation, collationType) + ")" case StructType(fields) => - "named_struct(" + ")" + fields.map(f => f.) + "named_struct(" + fields.map(f => "'" + f.name + "', " + + generateInputAsString(f.dataType, collationType)).mkString(", ") + ")" + } + + def generateInputTypeAsStrings( + inputType: AbstractDataType, + collationType: CollationType): String = + inputType match { + case AnyTimestampType => "TIMESTAMP" + case BinaryType => "BINARY" + case BooleanType => "BOOLEAN" + case _: DatetimeType => "DATE" + case _: DecimalType => "DECIMAL(2, 1)" + case IntegerType | NumericType => "INT" + case LongType => "BIGINT" + case _: StringType | AnyDataType | _: AbstractStringType => + collationType match { + case Utf8Binary => "STRING" + case Utf8BinaryLcase => "STRING COLLATE UTF8_LCASE" + } + case TypeCollection(typeCollection) => + val strTypes = typeCollection.filter(hasStringType) + if (strTypes.isEmpty) { + // Take first type + generateInputTypeAsStrings(typeCollection.head, collationType) + } else { + // Take first string type + generateInputTypeAsStrings(strTypes.head, collationType) + } + case AbstractArrayType(elementType) => + "array<" + generateInputTypeAsStrings(elementType, collationType) + ">" + case ArrayType(elementType, _) => + "array<" + generateInputTypeAsStrings(elementType, collationType) + ">" + case ArrayType => + "array<" + generateInputTypeAsStrings(StringTypeAnyCollation, collationType) + ">" + case MapType => + "map<" + generateInputTypeAsStrings(StringTypeAnyCollation, collationType) + ", " + + generateInputTypeAsStrings(StringTypeAnyCollation, collationType) + ">" + case MapType(keyType, valueType, _) => + "map<" + generateInputTypeAsStrings(keyType, collationType) + ", " + + generateInputTypeAsStrings(valueType, collationType) + ">" + case StructType => + "struct" + case 
StructType(fields) => + "named_struct<" + fields.map(f => "'" + f.name + "', " + + generateInputTypeAsStrings(f.dataType, collationType)).mkString(", ") + ">" } /** @@ -255,7 +302,7 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi case types: ExpectsInputTypes => expectsExpressionCounter = expectsExpressionCounter + 1 val inputTypes = types.inputTypes - inputTypes.exists(hasStringType) || inputTypes.isEmpty + inputTypes.exists(hasStringType) } } catch { case _: Throwable => false @@ -351,43 +398,20 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi } } - test("SPARK-48280: Expression Walker for plan generation of expressions") { + test("SPARK-48280: Expression Walker for codeGen generation") { var (funInfos, toSkip) = extractRelevantExpressions() - toSkip = - "element_at" :: - "try_element_at" :: - "reduce" :: - "aggregate" :: - "array_intersect" :: - toSkip - - var typeList = List( - "date", - "map,string>", - "string", - "array", - "float", - "smallint", - "map", - "map,struct<>>", - "struct", - "bigint", - "array>", - "timestamp", - "array>", - "array,value:struct<>>>", - "struct", - "timestamp_ntz", - "decimal(0,0)", - "double", - "int", - "map>", - "boolean", - "struct<>", - "binary" + toSkip = toSkip ++ List( + // Problem caught with other tests already + "map_from_arrays", + // These expressions are not called as functions + "lead", + "nth_value", + // Failing asserts + "session_window", + "ascii", + "to_xml" ) - val dt = collection.mutable.Set[String]() for (f <- funInfos.filter(f => !toSkip.contains(f.getName))) { println("checking - " + f.getName) val cl = Utils.classForName(f.getClassName) @@ -397,16 +421,113 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi val expr = headConstructor.newInstance(args: _*).asInstanceOf[ExpectsInputTypes] val inputTypes = expr.inputTypes - val inputDataUtf8Binary = { - generateData(replaceExpressions(inputTypes, 
params.toSeq), Utf8BinaryLcase) - } - println(inputDataUtf8Binary.map(_.toString)) - val instanceUtf8Binary = { - headConstructor.newInstance(inputDataUtf8Binary: _*).asInstanceOf[Expression] + withTable("tbl", "tbl_lcase") { + sql("CREATE TABLE tbl (" + + inputTypes.zipWithIndex + .map(it => "col" + + it._2.toString + " " + + generateInputTypeAsStrings(it._1, Utf8Binary)).mkString(", ") + + ") USING PARQUET") + sql("INSERT INTO tbl VALUES (" + + inputTypes.map(generateInputAsString(_, Utf8Binary)).mkString(", ") + + ")") + + sql("CREATE TABLE tbl_lcase (" + + inputTypes.zipWithIndex + .map(it => "col" + + it._2.toString + " " + + generateInputTypeAsStrings(it._1, Utf8BinaryLcase)).mkString(", ") + + ") USING PARQUET") + sql("INSERT INTO tbl_lcase VALUES (" + + inputTypes.map(generateInputAsString(_, Utf8BinaryLcase)).mkString(", ") + + ")") + + val utf8BinaryResult = try { + if (expr.isInstanceOf[BinaryComparison]) { + sql("SELECT " + "(col0 " + f.getName + "col1) FROM tbl") + } else { + if (inputTypes.size == 1) { + sql("SELECT " + f.getName + "(col0) FROM tbl") + } + else { + sql("SELECT " + f.getName + "(col0, " + + inputTypes.tail.map(generateInputAsString(_, Utf8Binary)).mkString(", ") + + ") FROM tbl") + } + }.getRows(1, 1) + None + } catch { + case e: Throwable => Some(e) + } + val utf8BinaryLcaseResult = try { + if (expr.isInstanceOf[BinaryComparison]) { + sql("SELECT " + "(col0 " + f.getName + "col1) FROM tbl_lcase") + } else { + if (inputTypes.size == 1) { + sql("SELECT " + f.getName + "(col0) FROM tbl_lcase") + } + else { + sql("SELECT " + f.getName + "(col0, " + + inputTypes.tail.map(generateInputAsString(_, Utf8BinaryLcase)).mkString(", ") + + ") FROM tbl_lcase") + } + }.getRows(1, 1) + None + } catch { + case e: Throwable => Some(e) + } + + assert(utf8BinaryResult.isDefined === utf8BinaryLcaseResult.isDefined) + + if (utf8BinaryResult.isEmpty) { + val utf8BinaryResult = + if (expr.isInstanceOf[BinaryComparison]) { + sql("SELECT " + "(col0 " + 
f.getName + "col1) FROM tbl") + } else { + if (inputTypes.size == 1) { + sql("SELECT " + f.getName + "(col0) FROM tbl") + } + else { + sql("SELECT " + f.getName + "(col0, " + + inputTypes.tail.map(generateInputAsString(_, Utf8Binary)).mkString(", ") + + ") FROM tbl") + } + } + val utf8BinaryLcaseResult = + if (expr.isInstanceOf[BinaryComparison]) { + sql("SELECT " + "(col0 " + f.getName + "col1) FROM tbl_lcase") + } else { + if (inputTypes.size == 1) { + sql("SELECT " + f.getName + "(col0) FROM tbl_lcase") + } + else { + sql("SELECT " + f.getName + "(col0, " + + inputTypes.tail.map(generateInputAsString(_, Utf8BinaryLcase)).mkString(", ") + + ") FROM tbl_lcase") + } + } + + val dt = utf8BinaryResult.schema.fields.head.dataType + + dt match { + case st if utf8BinaryResult != null && utf8BinaryLcaseResult != null && + hasStringType(st) => + // scalastyle:off caselocale + assert(utf8BinaryResult.getRows(1, 1).map(_.map(_.toLowerCase)) === + utf8BinaryLcaseResult.getRows(1, 1).map(_.map(_.toLowerCase))) + // scalastyle:on caselocale + case _ => + // scalastyle:off caselocale + assert(utf8BinaryResult.getRows(1, 1)(1) === + utf8BinaryLcaseResult.getRows(1, 1)(1)) + // scalastyle:on caselocale + } + } + else { + assert(utf8BinaryResult.get.getClass == utf8BinaryResult.get.getClass) + } } - dt += instanceUtf8Binary.dataType.simpleString } - println(dt) } test("SPARK-48280: Expression Walker for SQL query examples") { From 776dcbaee2c5481db839c142b7a03b7c386b1516 Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Thu, 13 Jun 2024 14:34:12 +0200 Subject: [PATCH 22/30] Fix test errors --- .../sql/CollationExpressionWalkerSuite.scala | 101 ++++++++++++------ 1 file changed, 67 insertions(+), 34 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala index 284d175501e14..a37820ebfe437 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala @@ -18,12 +18,11 @@ package org.apache.spark.sql import org.apache.spark.{SparkFunSuite, SparkRuntimeException} -import org.apache.spark.sql.catalyst.expressions.{BinaryComparison, CreateArray, EmptyRow, EvalMode, ExpectsInputTypes, Expression, ExpressionInfo, GenericInternalRow, Literal} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.internal.SqlApiConf -import org.apache.spark.sql.internal.types.{AbstractArrayType, AbstractStringType, StringTypeAnyCollation, StringTypeBinaryLcase} +import org.apache.spark.sql.internal.types._ import org.apache.spark.sql.test.SharedSparkSession -import org.apache.spark.sql.types.{AbstractDataType, AnyDataType, AnyTimestampType, ArrayType, BinaryType, BooleanType, DataType, DatetimeType, Decimal, DecimalType, IntegerType, LongType, MapType, NumericType, StringType, StructType, TypeCollection} -import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.sql.types._ import org.apache.spark.util.Utils /** @@ -100,12 +99,17 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi inputType match { // TODO: Try to make this a bit more random. 
case AnyTimestampType => Literal("2009-07-30 12:58:59") - case BinaryType => Literal(new Array[Byte](5)) + case BinaryType => collationType match { + case Utf8Binary => + Literal.create("dummy string".getBytes) + case Utf8BinaryLcase => + Literal.create("DuMmY sTrInG".getBytes) + } case BooleanType => Literal(true) - case _: DatetimeType => Literal(1L) + case _: DatetimeType => Literal(0L) case _: DecimalType => Literal(new Decimal) - case IntegerType | NumericType => Literal(1) - case LongType => Literal(1L) + case IntegerType | NumericType => Literal(0) + case LongType => Literal(0L) case _: StringType | AnyDataType | _: AbstractStringType => collationType match { case Utf8Binary => @@ -157,12 +161,16 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi inputType match { // TODO: Try to make this a bit more random. case AnyTimestampType => "TIMESTAMP'2009-07-30 12:58:59'" - case BinaryType => "X'0'" + case BinaryType => + collationType match { + case Utf8Binary => "Cast('dummy string' collate utf8_binary as BINARY)" + case Utf8BinaryLcase => "Cast('DuMmY sTrInG' collate utf8_lcase as BINARY)" + } case BooleanType => "True" case _: DatetimeType => "date'2016-04-08'" case _: DecimalType => "0.0" - case IntegerType | NumericType => "1" - case LongType => "1" + case IntegerType | NumericType => "0" + case LongType => "0" case _: StringType | AnyDataType | _: AbstractStringType => collationType match { case Utf8Binary => "'dummy string' COLLATE UTF8_BINARY" @@ -302,7 +310,7 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi case types: ExpectsInputTypes => expectsExpressionCounter = expectsExpressionCounter + 1 val inputTypes = types.inputTypes - inputTypes.exists(hasStringType) + inputTypes.exists(it => hasStringType(it) || it.isInstanceOf[BinaryType]) } } catch { case _: Throwable => false @@ -311,7 +319,17 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi val toSkip = 
List( "parse_url", // Parse URL cannot be generalized with ExpectInputTypes - "hex" // Different inputs affect conversion + "collation", // Expected to return different collation names + // Different inputs affect conversion + "hex", + "md5", + "sha1", + "unbase64", + "base64", + "sha2", + "sha", + "crc32", + "ascii" ) logInfo("Total number of expression: " + expressionCounter) @@ -349,14 +367,18 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi generateData(replaceExpressions(inputTypes, params.toSeq), Utf8Binary) val instanceUtf8Binary = headConstructor.newInstance(inputDataUtf8Binary: _*).asInstanceOf[Expression] - val inputDataLcase = generateData(replaceExpressions(inputTypes, params.toSeq), Utf8BinaryLcase) val instanceLcase = headConstructor.newInstance(inputDataLcase: _*).asInstanceOf[Expression] val exceptionUtfBinary = { try { - instanceUtf8Binary.eval(EmptyRow) + instanceUtf8Binary match { + case replaceable: RuntimeReplaceable => + replaceable.replacement.eval(EmptyRow) + case _ => + instanceUtf8Binary.eval(EmptyRow) + } None } catch { case e: Throwable => Some(e) @@ -365,7 +387,12 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi val exceptionLcase = { try { - instanceLcase.eval(EmptyRow) + instanceLcase match { + case replaceable: RuntimeReplaceable => + replaceable.replacement.eval(EmptyRow) + case _ => + instanceLcase.eval(EmptyRow) + } None } catch { case e: Throwable => Some(e) @@ -376,20 +403,28 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi assert(exceptionUtfBinary.isDefined == exceptionLcase.isDefined) if (exceptionUtfBinary.isEmpty) { - val resUtf8Binary = instanceUtf8Binary.eval(EmptyRow) - val resUtf8Lcase = instanceLcase.eval(EmptyRow) + val resUtf8Binary = instanceUtf8Binary match { + case replaceable: RuntimeReplaceable => + replaceable.replacement.eval(EmptyRow) + case _ => + instanceUtf8Binary.eval(EmptyRow) + } + val resUtf8Lcase = 
instanceLcase match { + case replaceable: RuntimeReplaceable => + replaceable.replacement.eval(EmptyRow) + case _ => + instanceLcase.eval(EmptyRow) + } val dt = instanceLcase.dataType dt match { - case _: StringType if resUtf8Lcase != null && resUtf8Lcase != null => - assert(resUtf8Binary.isInstanceOf[UTF8String]) - assert(resUtf8Lcase.isInstanceOf[UTF8String]) + case st if resUtf8Lcase != null && resUtf8Lcase != null && hasStringType(st) => // scalastyle:off caselocale - assert(resUtf8Binary.asInstanceOf[UTF8String].toLowerCase.binaryEquals( - resUtf8Lcase.asInstanceOf[UTF8String].toLowerCase)) - // scalastyle:on caselocale - case _ => resUtf8Lcase === resUtf8Binary + assert(resUtf8Binary.toString.toLowerCase === resUtf8Lcase.toString.toLowerCase) + // scalastyle:on caselocale + case _ => + assert(resUtf8Lcase === resUtf8Binary) } } else { @@ -407,13 +442,11 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi // These expressions are not called as functions "lead", "nth_value", - // Failing asserts "session_window", - "ascii", + // Unexpected to fail "to_xml" ) for (f <- funInfos.filter(f => !toSkip.contains(f.getName))) { - println("checking - " + f.getName) val cl = Utils.classForName(f.getClassName) val headConstructor = cl.getConstructors.head val params = headConstructor.getParameters.map(p => p.getType) @@ -454,7 +487,7 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi inputTypes.tail.map(generateInputAsString(_, Utf8Binary)).mkString(", ") + ") FROM tbl") } - }.getRows(1, 1) + }.getRows(1, 0) None } catch { case e: Throwable => Some(e) @@ -471,7 +504,7 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi inputTypes.tail.map(generateInputAsString(_, Utf8BinaryLcase)).mkString(", ") + ") FROM tbl_lcase") } - }.getRows(1, 1) + }.getRows(1, 0) None } catch { case e: Throwable => Some(e) @@ -513,13 +546,13 @@ class CollationExpressionWalkerSuite extends SparkFunSuite 
with SharedSparkSessi case st if utf8BinaryResult != null && utf8BinaryLcaseResult != null && hasStringType(st) => // scalastyle:off caselocale - assert(utf8BinaryResult.getRows(1, 1).map(_.map(_.toLowerCase)) === - utf8BinaryLcaseResult.getRows(1, 1).map(_.map(_.toLowerCase))) + assert(utf8BinaryResult.getRows(1, 0).map(_.map(_.toLowerCase)) === + utf8BinaryLcaseResult.getRows(1, 0).map(_.map(_.toLowerCase))) // scalastyle:on caselocale case _ => // scalastyle:off caselocale - assert(utf8BinaryResult.getRows(1, 1)(1) === - utf8BinaryLcaseResult.getRows(1, 1)(1)) + assert(utf8BinaryResult.getRows(1, 0)(1) === + utf8BinaryLcaseResult.getRows(1, 0)(1)) // scalastyle:on caselocale } } From ced55006384d4284ad1be8cc4b3afcabd8a96068 Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Fri, 14 Jun 2024 04:50:13 +0200 Subject: [PATCH 23/30] Add new test --- .../sql/CollationExpressionWalkerSuite.scala | 167 ++++++++++++++++++ 1 file changed, 167 insertions(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala index a37820ebfe437..47b61df5e65e8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala @@ -51,6 +51,12 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi inputEntry.map(generateSingleEntry(_, collationType)) } + /** + * Helper function to generate single entry of data as a string. 
+ * @param inputEntry - Single input entry that requires generation + * @param collationType - Flag defining collation type to use + * @return + */ def generateDataAsStrings( inputEntry: Seq[AbstractDataType], collationType: CollationType): Seq[Any] = { @@ -155,6 +161,13 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi StructType(fields)) } + /** + * Helper function to generate single input as a string from the given type. + * + * @param inputType - Single input type that requires generation + * @param collationType - Flag defining collation type to use + * @return + */ def generateInputAsString( inputType: AbstractDataType, collationType: CollationType): String = @@ -205,6 +218,13 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi generateInputAsString(f.dataType, collationType)).mkString(", ") + ")" } + /** + * Helper function to generate single input type as string from the given type. + * + * @param inputType - Single input type that requires generation + * @param collationType - Flag defining collation type to use + * @return + */ def generateInputTypeAsStrings( inputType: AbstractDataType, collationType: CollationType): String = @@ -273,6 +293,12 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi } } + /** + * Helper function to replace expected parameters with expected input types. + * @param inputTypes - Input types generated by ExpectsInputType.inputTypes + * @param params - Parameters that are read from expression info + * @return + */ def replaceExpressions(inputTypes: Seq[AbstractDataType], params: Seq[Class[_]]): Seq[Any] = { (inputTypes, params) match { case (Nil, mparams) => mparams @@ -284,6 +310,10 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi } } + /** + * Helper method to extract relevant expressions that can be walked over. 
+ * @return + */ def extractRelevantExpressions(): (Array[ExpressionInfo], List[String]) = { var expressionCounter = 0 var expectsExpressionCounter = 0; @@ -339,6 +369,13 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi (funInfos, toSkip) } + /** + * This test does following: + * 1) Take all expressions + * 2) Find the ones that have at least one argument of StringType + * 3) Use reflection to create an instance of the expression using first constructor + * 4) Check if the expression is of type ExpectsInputTypes (should make this a bit broader) + */ test("SPARK-48280: Expression Walker for expression evaluation") { // This test does following: // 1) Take all expressions @@ -563,6 +600,136 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi } } + test("SPARK-48280: Expression Walker for codeGen generation with photonization") { + + var (funInfos, toSkip) = extractRelevantExpressions() + toSkip = toSkip ++ List( + // Problem caught with other tests already + "map_from_arrays", + // These expressions are not called as functions + "lead", + "nth_value", + "session_window", + // Unexpected to fail + "to_xml" + ) + for (f <- funInfos.filter(f => !toSkip.contains(f.getName))) { + val cl = Utils.classForName(f.getClassName) + val headConstructor = cl.getConstructors.head + val params = headConstructor.getParameters.map(p => p.getType) + val args = generateData(params.toSeq, Utf8Binary) + val expr = headConstructor.newInstance(args: _*).asInstanceOf[ExpectsInputTypes] + val inputTypes = expr.inputTypes + + withTable("tbl", "tbl_lcase") { + sql("CREATE TABLE tbl (" + + inputTypes.zipWithIndex + .map(it => "col" + + it._2.toString + " " + + generateInputTypeAsStrings(it._1, Utf8Binary)).mkString(", ") + + ") USING PARQUET") + sql("INSERT INTO tbl VALUES (" + + inputTypes.map(generateInputAsString(_, Utf8Binary)).mkString(", ") + + ")") + + sql("CREATE TABLE tbl_lcase (" + + inputTypes.zipWithIndex + .map(it => 
"col" + + it._2.toString + " " + + generateInputTypeAsStrings(it._1, Utf8BinaryLcase)).mkString(", ") + + ") USING PARQUET") + sql("INSERT INTO tbl_lcase VALUES (" + + inputTypes.map(generateInputAsString(_, Utf8BinaryLcase)).mkString(", ") + + ")") + + val utf8BinaryResult = try { + if (expr.isInstanceOf[BinaryComparison]) { + sql("SELECT " + "(col0 " + f.getName + "col1) FROM tbl") + } else { + if (inputTypes.size == 1) { + sql("SELECT " + f.getName + "(col0) FROM tbl") + } + else { + sql("SELECT " + f.getName + "(col0, " + + inputTypes.tail.map(generateInputAsString(_, Utf8Binary)).mkString(", ") + + ") FROM tbl") + } + }.getRows(1, 0) + None + } catch { + case e: Throwable => Some(e) + } + val utf8BinaryLcaseResult = try { + if (expr.isInstanceOf[BinaryComparison]) { + sql("SELECT " + "(col0 " + f.getName + "col1) FROM tbl_lcase") + } else { + if (inputTypes.size == 1) { + sql("SELECT " + f.getName + "(col0) FROM tbl_lcase") + } + else { + sql("SELECT " + f.getName + "(col0, " + + inputTypes.tail.map(generateInputAsString(_, Utf8BinaryLcase)).mkString(", ") + + ") FROM tbl_lcase") + } + }.getRows(1, 0) + None + } catch { + case e: Throwable => Some(e) + } + + assert(utf8BinaryResult.isDefined === utf8BinaryLcaseResult.isDefined) + + if (utf8BinaryResult.isEmpty) { + val utf8BinaryResult = + if (expr.isInstanceOf[BinaryComparison]) { + sql("SELECT " + "(col0 " + f.getName + "col1) FROM tbl") + } else { + if (inputTypes.size == 1) { + sql("SELECT " + f.getName + "(col0) FROM tbl") + } + else { + sql("SELECT " + f.getName + "(col0, " + + inputTypes.tail.map(generateInputAsString(_, Utf8Binary)).mkString(", ") + + ") FROM tbl") + } + } + val utf8BinaryLcaseResult = + if (expr.isInstanceOf[BinaryComparison]) { + sql("SELECT " + "(col0 " + f.getName + "col1) FROM tbl_lcase") + } else { + if (inputTypes.size == 1) { + sql("SELECT " + f.getName + "(col0) FROM tbl_lcase") + } + else { + sql("SELECT " + f.getName + "(col0, " + + 
inputTypes.tail.map(generateInputAsString(_, Utf8BinaryLcase)).mkString(", ") + + ") FROM tbl_lcase") + } + } + + val dt = utf8BinaryResult.schema.fields.head.dataType + + dt match { + case st if utf8BinaryResult != null && utf8BinaryLcaseResult != null && + hasStringType(st) => + // scalastyle:off caselocale + assert(utf8BinaryResult.getRows(1, 0).map(_.map(_.toLowerCase)) === + utf8BinaryLcaseResult.getRows(1, 0).map(_.map(_.toLowerCase))) + // scalastyle:on caselocale + case _ => + // scalastyle:off caselocale + assert(utf8BinaryResult.getRows(1, 0)(1) === + utf8BinaryLcaseResult.getRows(1, 0)(1)) + // scalastyle:on caselocale + } + } + else { + assert(utf8BinaryResult.get.getClass == utf8BinaryResult.get.getClass) + } + } + } + } + test("SPARK-48280: Expression Walker for SQL query examples") { val funInfos = spark.sessionState.functionRegistry.listFunction().map { funcId => spark.sessionState.catalog.lookupFunctionInfo(funcId) From 9446fd151d9f4b0927ceef6a5478258c8e535c20 Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Tue, 18 Jun 2024 11:07:18 +0200 Subject: [PATCH 24/30] Enable fixed expressions --- .../sql/CollationExpressionWalkerSuite.scala | 139 ------------------ 1 file changed, 139 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala index 47b61df5e65e8..73b517c9490d6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala @@ -474,8 +474,6 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi var (funInfos, toSkip) = extractRelevantExpressions() toSkip = toSkip ++ List( - // Problem caught with other tests already - "map_from_arrays", // These expressions are not called as functions "lead", "nth_value", @@ -600,136 +598,6 @@ class 
CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi } } - test("SPARK-48280: Expression Walker for codeGen generation with photonization") { - - var (funInfos, toSkip) = extractRelevantExpressions() - toSkip = toSkip ++ List( - // Problem caught with other tests already - "map_from_arrays", - // These expressions are not called as functions - "lead", - "nth_value", - "session_window", - // Unexpected to fail - "to_xml" - ) - for (f <- funInfos.filter(f => !toSkip.contains(f.getName))) { - val cl = Utils.classForName(f.getClassName) - val headConstructor = cl.getConstructors.head - val params = headConstructor.getParameters.map(p => p.getType) - val args = generateData(params.toSeq, Utf8Binary) - val expr = headConstructor.newInstance(args: _*).asInstanceOf[ExpectsInputTypes] - val inputTypes = expr.inputTypes - - withTable("tbl", "tbl_lcase") { - sql("CREATE TABLE tbl (" + - inputTypes.zipWithIndex - .map(it => "col" + - it._2.toString + " " + - generateInputTypeAsStrings(it._1, Utf8Binary)).mkString(", ") + - ") USING PARQUET") - sql("INSERT INTO tbl VALUES (" + - inputTypes.map(generateInputAsString(_, Utf8Binary)).mkString(", ") + - ")") - - sql("CREATE TABLE tbl_lcase (" + - inputTypes.zipWithIndex - .map(it => "col" + - it._2.toString + " " + - generateInputTypeAsStrings(it._1, Utf8BinaryLcase)).mkString(", ") + - ") USING PARQUET") - sql("INSERT INTO tbl_lcase VALUES (" + - inputTypes.map(generateInputAsString(_, Utf8BinaryLcase)).mkString(", ") + - ")") - - val utf8BinaryResult = try { - if (expr.isInstanceOf[BinaryComparison]) { - sql("SELECT " + "(col0 " + f.getName + "col1) FROM tbl") - } else { - if (inputTypes.size == 1) { - sql("SELECT " + f.getName + "(col0) FROM tbl") - } - else { - sql("SELECT " + f.getName + "(col0, " + - inputTypes.tail.map(generateInputAsString(_, Utf8Binary)).mkString(", ") + - ") FROM tbl") - } - }.getRows(1, 0) - None - } catch { - case e: Throwable => Some(e) - } - val utf8BinaryLcaseResult = try { - if 
(expr.isInstanceOf[BinaryComparison]) { - sql("SELECT " + "(col0 " + f.getName + "col1) FROM tbl_lcase") - } else { - if (inputTypes.size == 1) { - sql("SELECT " + f.getName + "(col0) FROM tbl_lcase") - } - else { - sql("SELECT " + f.getName + "(col0, " + - inputTypes.tail.map(generateInputAsString(_, Utf8BinaryLcase)).mkString(", ") + - ") FROM tbl_lcase") - } - }.getRows(1, 0) - None - } catch { - case e: Throwable => Some(e) - } - - assert(utf8BinaryResult.isDefined === utf8BinaryLcaseResult.isDefined) - - if (utf8BinaryResult.isEmpty) { - val utf8BinaryResult = - if (expr.isInstanceOf[BinaryComparison]) { - sql("SELECT " + "(col0 " + f.getName + "col1) FROM tbl") - } else { - if (inputTypes.size == 1) { - sql("SELECT " + f.getName + "(col0) FROM tbl") - } - else { - sql("SELECT " + f.getName + "(col0, " + - inputTypes.tail.map(generateInputAsString(_, Utf8Binary)).mkString(", ") + - ") FROM tbl") - } - } - val utf8BinaryLcaseResult = - if (expr.isInstanceOf[BinaryComparison]) { - sql("SELECT " + "(col0 " + f.getName + "col1) FROM tbl_lcase") - } else { - if (inputTypes.size == 1) { - sql("SELECT " + f.getName + "(col0) FROM tbl_lcase") - } - else { - sql("SELECT " + f.getName + "(col0, " + - inputTypes.tail.map(generateInputAsString(_, Utf8BinaryLcase)).mkString(", ") + - ") FROM tbl_lcase") - } - } - - val dt = utf8BinaryResult.schema.fields.head.dataType - - dt match { - case st if utf8BinaryResult != null && utf8BinaryLcaseResult != null && - hasStringType(st) => - // scalastyle:off caselocale - assert(utf8BinaryResult.getRows(1, 0).map(_.map(_.toLowerCase)) === - utf8BinaryLcaseResult.getRows(1, 0).map(_.map(_.toLowerCase))) - // scalastyle:on caselocale - case _ => - // scalastyle:off caselocale - assert(utf8BinaryResult.getRows(1, 0)(1) === - utf8BinaryLcaseResult.getRows(1, 0)(1)) - // scalastyle:on caselocale - } - } - else { - assert(utf8BinaryResult.get.getClass == utf8BinaryResult.get.getClass) - } - } - } - } - test("SPARK-48280: Expression Walker 
for SQL query examples") { val funInfos = spark.sessionState.functionRegistry.listFunction().map { funcId => spark.sessionState.catalog.lookupFunctionInfo(funcId) @@ -760,15 +628,8 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi "uuid", "shuffle", // other functions which are not yet supported - "date_sub", - "date_add", - "dateadd", - "window", - "window_time", - "session_window", "reflect", "try_reflect", - "levenshtein", "java_method" ) From 9902b05346f8fe0d3ff489e697615064cb2c9650 Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Thu, 20 Jun 2024 12:57:52 +0200 Subject: [PATCH 25/30] Polish code and improve tests --- .../sql/CollationExpressionWalkerSuite.scala | 209 +++++++++--------- 1 file changed, 104 insertions(+), 105 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala index 73b517c9490d6..81401d790a3ae 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql import org.apache.spark.{SparkFunSuite, SparkRuntimeException} import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.variant.ParseJson import org.apache.spark.sql.internal.SqlApiConf import org.apache.spark.sql.internal.types._ import org.apache.spark.sql.test.SharedSparkSession @@ -114,7 +115,8 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi case BooleanType => Literal(true) case _: DatetimeType => Literal(0L) case _: DecimalType => Literal(new Decimal) - case IntegerType | NumericType => Literal(0) + case _: DoubleType => Literal(0.0) + case IntegerType | NumericType | IntegralType => Literal(0) case LongType => Literal(0L) case _: StringType | AnyDataType | _: 
AbstractStringType => collationType match { @@ -123,6 +125,7 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi case Utf8BinaryLcase => Literal.create("DuMmY sTrInG", StringType("UTF8_LCASE")) } + case VariantType => ParseJson(generateLiterals(StringTypeAnyCollation, collationType)) case TypeCollection(typeCollection) => val strTypes = typeCollection.filter(hasStringType) if (strTypes.isEmpty) { @@ -153,12 +156,9 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi val value = generateLiterals(valueType, collationType) Literal.create(Map(key -> value)) case StructType => - Literal.create((generateLiterals(StringTypeAnyCollation, collationType), - generateLiterals(StringTypeAnyCollation, collationType))) - case StructType(fields) => - Literal.create(new GenericInternalRow( - fields.map(f => generateLiterals(f.dataType, collationType).asInstanceOf[Any])), - StructType(fields)) + CreateNamedStruct( + Seq(Literal("start"), generateLiterals(StringTypeAnyCollation, collationType), + Literal("end"), generateLiterals(StringTypeAnyCollation, collationType))) } /** @@ -182,13 +182,15 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi case BooleanType => "True" case _: DatetimeType => "date'2016-04-08'" case _: DecimalType => "0.0" - case IntegerType | NumericType => "0" + case _: DoubleType => "0.0" + case IntegerType | NumericType | IntegralType => "0" case LongType => "0" case _: StringType | AnyDataType | _: AbstractStringType => collationType match { case Utf8Binary => "'dummy string' COLLATE UTF8_BINARY" case Utf8BinaryLcase => "'DuMmY sTrInG' COLLATE UTF8_LCASE" } + case VariantType => s"parse_json('1')" case TypeCollection(typeCollection) => val strTypes = typeCollection.filter(hasStringType) if (strTypes.isEmpty) { @@ -234,13 +236,15 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi case BooleanType => "BOOLEAN" case _: DatetimeType => 
"DATE" case _: DecimalType => "DECIMAL(2, 1)" - case IntegerType | NumericType => "INT" + case _: DoubleType => "DOUBLE" + case IntegerType | NumericType | IntegralType => "INT" case LongType => "BIGINT" case _: StringType | AnyDataType | _: AbstractStringType => collationType match { case Utf8Binary => "STRING" case Utf8BinaryLcase => "STRING COLLATE UTF8_LCASE" } + case VariantType => "VARIANT" case TypeCollection(typeCollection) => val strTypes = typeCollection.filter(hasStringType) if (strTypes.isEmpty) { @@ -337,10 +341,11 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi try { val expr = headConstructor.newInstance(args: _*) expr match { - case types: ExpectsInputTypes => + case expTypes: ExpectsInputTypes => expectsExpressionCounter = expectsExpressionCounter + 1 - val inputTypes = types.inputTypes - inputTypes.exists(it => hasStringType(it) || it.isInstanceOf[BinaryType]) + val inputTypes = expTypes.inputTypes + inputTypes.exists(it => hasStringType(it)) || + (inputTypes.nonEmpty && hasStringType(expTypes.dataType)) } } catch { case _: Throwable => false @@ -369,30 +374,67 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi (funInfos, toSkip) } + /** + * Helper function to generate string of an expression suitable for execution. + * @param expr - Expression that needs to be converted + * @param collationType - Defines explicit collation to use + * @return + */ + def transformExpressionToString(expr: ExpectsInputTypes, collationType: CollationType): String = { + if (expr.isInstanceOf[BinaryComparison]) { + "col0 " + expr.prettyName + " col1" + } else { + if (expr.inputTypes.size == 1) { + expr.prettyName + "(col0)" + } + else { + expr.prettyName + "(col0, " + + expr.inputTypes.tail.map(generateInputAsString(_, collationType)).mkString(", ") + ")" + } + } + } + + /** + * Helper function to generate input data for the dataframe. 
+ * @param inputTypes - Column types that need to be generated + * @param collationType - Defines explicit collation to use + * @return + */ + def generateTableData( + inputTypes: Seq[AbstractDataType], + collationType: CollationType): DataFrame = { + val tblName = collationType match { + case Utf8Binary => "tbl" + case Utf8BinaryLcase => "tbl_lcase" + } + + sql(s"CREATE TABLE $tblName (" + + inputTypes.zipWithIndex + .map(it => "col" + + it._2.toString + " " + + generateInputTypeAsStrings(it._1, collationType)).mkString(", ") + + ") USING PARQUET") + + sql(s"INSERT INTO $tblName VALUES (" + + inputTypes.map(generateInputAsString(_, collationType)).mkString(", ") + + ")") + + sql(s"SELECT * FROM $tblName") + } + /** * This test does following: - * 1) Take all expressions - * 2) Find the ones that have at least one argument of StringType - * 3) Use reflection to create an instance of the expression using first constructor - * 4) Check if the expression is of type ExpectsInputTypes (should make this a bit broader) + * 1) Extract relevant expressions + * 2) Run evaluation on expressions with different inputs + * 3) Check if both expressions throw an exception + * 4) If no exception, check if the result is the same + * 5) Otherwise, check if exceptions are the same */ test("SPARK-48280: Expression Walker for expression evaluation") { - // This test does following: - // 1) Take all expressions - // 2) Find the ones that have at least one argument of StringType - // 3) Use reflection to create an instance of the expression using first constructor - // (test other as well). - // 4) Check if the expression is of type ExpectsInputTypes (should make this a bit broader) - // 5) Run eval against literals with strings under: - // a) UTF8_BINARY, "dummy string" as input. - // b) UTF8_BINARY_LCASE, "DuMmY sTrInG" as input. - // 6) Check if both expressions throw an exception. - // 7) If no exception, check if the result is the same. 
- // 8) There is a list of allowed expressions that can differ (e.g. hex) - val (funInfos, toSkip) = extractRelevantExpressions() for (f <- funInfos.filter(f => !toSkip.contains(f.getName))) { + println("checking - " + f.getName) val cl = Utils.classForName(f.getClassName) val headConstructor = cl.getConstructors.head val params = headConstructor.getParameters.map(p => p.getType) @@ -432,7 +474,9 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi } None } catch { - case e: Throwable => Some(e) + case e: Throwable => + println(e.getMessage) + Some(e) } } @@ -471,15 +515,22 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi } test("SPARK-48280: Expression Walker for codeGen generation") { - + /** + * This test does following: + * 1) Extract relevant expressions + * 2) Run dataframe select on expressions with different inputs + * 3) Check if both expressions throw an exception + * 4) If no exception, check if the result is the same + * 5) Otherwise, check if exceptions are the same + */ var (funInfos, toSkip) = extractRelevantExpressions() toSkip = toSkip ++ List( + // Known to be faulty + "to_xml", // These expressions are not called as functions "lead", "nth_value", - "session_window", - // Unexpected to fail - "to_xml" + "session_window" ) for (f <- funInfos.filter(f => !toSkip.contains(f.getName))) { val cl = Utils.classForName(f.getClassName) @@ -487,59 +538,22 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi val params = headConstructor.getParameters.map(p => p.getType) val args = generateData(params.toSeq, Utf8Binary) val expr = headConstructor.newInstance(args: _*).asInstanceOf[ExpectsInputTypes] - val inputTypes = expr.inputTypes withTable("tbl", "tbl_lcase") { - sql("CREATE TABLE tbl (" + - inputTypes.zipWithIndex - .map(it => "col" + - it._2.toString + " " + - generateInputTypeAsStrings(it._1, Utf8Binary)).mkString(", ") + - ") USING PARQUET") - sql("INSERT 
INTO tbl VALUES (" + - inputTypes.map(generateInputAsString(_, Utf8Binary)).mkString(", ") + - ")") - - sql("CREATE TABLE tbl_lcase (" + - inputTypes.zipWithIndex - .map(it => "col" + - it._2.toString + " " + - generateInputTypeAsStrings(it._1, Utf8BinaryLcase)).mkString(", ") + - ") USING PARQUET") - sql("INSERT INTO tbl_lcase VALUES (" + - inputTypes.map(generateInputAsString(_, Utf8BinaryLcase)).mkString(", ") + - ")") + + val utf8_df = generateTableData(expr.inputTypes, Utf8Binary) + val utf8_lcase_df = generateTableData(expr.inputTypes, Utf8BinaryLcase) val utf8BinaryResult = try { - if (expr.isInstanceOf[BinaryComparison]) { - sql("SELECT " + "(col0 " + f.getName + "col1) FROM tbl") - } else { - if (inputTypes.size == 1) { - sql("SELECT " + f.getName + "(col0) FROM tbl") - } - else { - sql("SELECT " + f.getName + "(col0, " + - inputTypes.tail.map(generateInputAsString(_, Utf8Binary)).mkString(", ") + - ") FROM tbl") - } - }.getRows(1, 0) + utf8_df.selectExpr(transformExpressionToString(expr, Utf8Binary)) + .getRows(1, 0) None } catch { case e: Throwable => Some(e) } val utf8BinaryLcaseResult = try { - if (expr.isInstanceOf[BinaryComparison]) { - sql("SELECT " + "(col0 " + f.getName + "col1) FROM tbl_lcase") - } else { - if (inputTypes.size == 1) { - sql("SELECT " + f.getName + "(col0) FROM tbl_lcase") - } - else { - sql("SELECT " + f.getName + "(col0, " + - inputTypes.tail.map(generateInputAsString(_, Utf8BinaryLcase)).mkString(", ") + - ") FROM tbl_lcase") - } - }.getRows(1, 0) + utf8_lcase_df.selectExpr(transformExpressionToString(expr, Utf8BinaryLcase)) + .getRows(1, 0) None } catch { case e: Throwable => Some(e) @@ -549,31 +563,9 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi if (utf8BinaryResult.isEmpty) { val utf8BinaryResult = - if (expr.isInstanceOf[BinaryComparison]) { - sql("SELECT " + "(col0 " + f.getName + "col1) FROM tbl") - } else { - if (inputTypes.size == 1) { - sql("SELECT " + f.getName + "(col0) FROM 
tbl") - } - else { - sql("SELECT " + f.getName + "(col0, " + - inputTypes.tail.map(generateInputAsString(_, Utf8Binary)).mkString(", ") + - ") FROM tbl") - } - } + utf8_df.selectExpr(transformExpressionToString(expr, Utf8Binary)) val utf8BinaryLcaseResult = - if (expr.isInstanceOf[BinaryComparison]) { - sql("SELECT " + "(col0 " + f.getName + "col1) FROM tbl_lcase") - } else { - if (inputTypes.size == 1) { - sql("SELECT " + f.getName + "(col0) FROM tbl_lcase") - } - else { - sql("SELECT " + f.getName + "(col0, " + - inputTypes.tail.map(generateInputAsString(_, Utf8BinaryLcase)).mkString(", ") + - ") FROM tbl_lcase") - } - } + utf8_lcase_df.selectExpr(transformExpressionToString(expr, Utf8BinaryLcase)) val dt = utf8BinaryResult.schema.fields.head.dataType @@ -599,6 +591,14 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi } test("SPARK-48280: Expression Walker for SQL query examples") { + /** + * This test does following: + * 1) Extract all expressions + * 2) Run example queries for different session level default collations + * 3) Check if both expressions throw an exception + * 4) If no exception, check if the result is the same + * 5) Otherwise, check if exceptions are the same + */ val funInfos = spark.sessionState.functionRegistry.listFunction().map { funcId => spark.sessionState.catalog.lookupFunctionInfo(funcId) } @@ -634,18 +634,17 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi ) for (funInfo <- funInfos.filter(f => !toSkip.contains(f.getName))) { - println("checking - " + funInfo.getName) for (m <- "> .*;".r.findAllIn(funInfo.getExamples)) { try { val resultUTF8 = sql(m.substring(2)) - withSQLConf(SqlApiConf.DEFAULT_COLLATION -> "UTF8_BINARY_LCASE") { + withSQLConf(SqlApiConf.DEFAULT_COLLATION -> "UTF8_LCASE") { val resultUTF8Lcase = sql(m.substring(2)) assert(resultUTF8.collect() === resultUTF8Lcase.collect()) } } catch { case e: SparkRuntimeException => assert(e.getErrorClass == 
"USER_RAISED_EXCEPTION") - case other: Throwable => other + case other: Throwable => throw other } } } From e75ff5caffa3d5c75a4d655976de368194bc6684 Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Mon, 1 Jul 2024 13:57:56 +0200 Subject: [PATCH 26/30] Incorporate changes --- .../sql/CollationExpressionWalkerSuite.scala | 142 +++++++++--------- 1 file changed, 71 insertions(+), 71 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala index 81401d790a3ae..cc25bccbde808 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql +import java.sql.Timestamp + import org.apache.spark.{SparkFunSuite, SparkRuntimeException} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.variant.ParseJson @@ -40,7 +42,7 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi case object Utf8BinaryLcase extends CollationType /** - * Helper function to generate all necesary parameters + * Helper function to generate all necessary parameters * * @param inputEntry - List of all input entries that need to be generated * @param collationType - Flag defining collation type to use @@ -105,7 +107,7 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi collationType: CollationType): Expression = inputType match { // TODO: Try to make this a bit more random. 
- case AnyTimestampType => Literal("2009-07-30 12:58:59") + case AnyTimestampType => Literal(Timestamp.valueOf("2009-07-30 12:58:59")) case BinaryType => collationType match { case Utf8Binary => Literal.create("dummy string".getBytes) @@ -113,11 +115,11 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi Literal.create("DuMmY sTrInG".getBytes) } case BooleanType => Literal(true) - case _: DatetimeType => Literal(0L) - case _: DecimalType => Literal(new Decimal) - case _: DoubleType => Literal(0.0) - case IntegerType | NumericType | IntegralType => Literal(0) - case LongType => Literal(0L) + case _: DatetimeType => Literal(Timestamp.valueOf("2009-07-30 12:58:59")) + case _: DecimalType => Literal((new Decimal).set(5)) + case _: DoubleType => Literal(5.0) + case IntegerType | NumericType | IntegralType => Literal(5) + case LongType => Literal(5L) case _: StringType | AnyDataType | _: AbstractStringType => collationType match { case Utf8Binary => @@ -125,7 +127,12 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi case Utf8BinaryLcase => Literal.create("DuMmY sTrInG", StringType("UTF8_LCASE")) } - case VariantType => ParseJson(generateLiterals(StringTypeAnyCollation, collationType)) + case VariantType => collationType match { + case Utf8Binary => + ParseJson(Literal.create("{}", StringType("UTF8_BINARY"))) + case Utf8BinaryLcase => + ParseJson(Literal.create("{}", StringType("UTF8_LCASE"))) + } case TypeCollection(typeCollection) => val strTypes = typeCollection.filter(hasStringType) if (strTypes.isEmpty) { @@ -181,16 +188,16 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi } case BooleanType => "True" case _: DatetimeType => "date'2016-04-08'" - case _: DecimalType => "0.0" - case _: DoubleType => "0.0" - case IntegerType | NumericType | IntegralType => "0" - case LongType => "0" + case _: DecimalType => "5.0" + case _: DoubleType => "5.0" + case IntegerType | 
NumericType | IntegralType => "5" + case LongType => "5L" case _: StringType | AnyDataType | _: AbstractStringType => collationType match { case Utf8Binary => "'dummy string' COLLATE UTF8_BINARY" case Utf8BinaryLcase => "'DuMmY sTrInG' COLLATE UTF8_LCASE" } - case VariantType => s"parse_json('1')" + case VariantType => s"parse_json('{}')" case TypeCollection(typeCollection) => val strTypes = typeCollection.filter(hasStringType) if (strTypes.isEmpty) { @@ -332,7 +339,8 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi val cl = Utils.classForName(funInfo.getClassName) // dummy instance // Take first constructor. - val headConstructor = cl.getConstructors.head + val headConstructor = cl.getConstructors + .zip(cl.getConstructors.map(c => c.getParameters.length)).minBy(a => a._2)._1 val params = headConstructor.getParameters.map(p => p.getType) @@ -381,8 +389,8 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi * @return */ def transformExpressionToString(expr: ExpectsInputTypes, collationType: CollationType): String = { - if (expr.isInstanceOf[BinaryComparison]) { - "col0 " + expr.prettyName + " col1" + if (expr.isInstanceOf[BinaryOperator]) { + "col0 " + expr.asInstanceOf[BinaryOperator].symbol + " col1" } else { if (expr.inputTypes.size == 1) { expr.prettyName + "(col0)" @@ -434,68 +442,60 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi val (funInfos, toSkip) = extractRelevantExpressions() for (f <- funInfos.filter(f => !toSkip.contains(f.getName))) { - println("checking - " + f.getName) val cl = Utils.classForName(f.getClassName) - val headConstructor = cl.getConstructors.head + val headConstructor = cl.getConstructors + .zip(cl.getConstructors.map(c => c.getParameters.length)).minBy(a => a._2)._1 val params = headConstructor.getParameters.map(p => p.getType) val args = generateData(params.toSeq, Utf8Binary) val expr = headConstructor.newInstance(args: 
_*).asInstanceOf[ExpectsInputTypes] val inputTypes = expr.inputTypes val inputDataUtf8Binary = - generateData(replaceExpressions(inputTypes, params.toSeq), Utf8Binary) + generateData( + replaceExpressions(inputTypes, headConstructor.getParameters.map(p => p.getType).toSeq), + Utf8Binary + ) val instanceUtf8Binary = headConstructor.newInstance(inputDataUtf8Binary: _*).asInstanceOf[Expression] val inputDataLcase = - generateData(replaceExpressions(inputTypes, params.toSeq), Utf8BinaryLcase) + generateData( + replaceExpressions(inputTypes, headConstructor.getParameters.map(p => p.getType).toSeq), + Utf8BinaryLcase + ) val instanceLcase = headConstructor.newInstance(inputDataLcase: _*).asInstanceOf[Expression] val exceptionUtfBinary = { try { - instanceUtf8Binary match { + scala.util.Right(instanceUtf8Binary match { case replaceable: RuntimeReplaceable => replaceable.replacement.eval(EmptyRow) case _ => instanceUtf8Binary.eval(EmptyRow) - } - None + }) } catch { - case e: Throwable => Some(e) + case e: Throwable => scala.util.Left(e) } } val exceptionLcase = { try { - instanceLcase match { + scala.util.Right(instanceLcase match { case replaceable: RuntimeReplaceable => replaceable.replacement.eval(EmptyRow) case _ => instanceLcase.eval(EmptyRow) - } - None + }) } catch { - case e: Throwable => - println(e.getMessage) - Some(e) + case e: Throwable => scala.util.Left(e) } } // Check that both cases either throw or pass - assert(exceptionUtfBinary.isDefined == exceptionLcase.isDefined) + assert(exceptionUtfBinary.isRight == exceptionLcase.isRight) - if (exceptionUtfBinary.isEmpty) { - val resUtf8Binary = instanceUtf8Binary match { - case replaceable: RuntimeReplaceable => - replaceable.replacement.eval(EmptyRow) - case _ => - instanceUtf8Binary.eval(EmptyRow) - } - val resUtf8Lcase = instanceLcase match { - case replaceable: RuntimeReplaceable => - replaceable.replacement.eval(EmptyRow) - case _ => - instanceLcase.eval(EmptyRow) - } + if (exceptionUtfBinary.isRight) { + 
val resUtf8Binary = exceptionUtfBinary.getOrElse(null) + val resUtf8Lcase = exceptionLcase.getOrElse(null) val dt = instanceLcase.dataType @@ -509,32 +509,34 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi } } else { - assert(exceptionUtfBinary.get.getClass == exceptionLcase.get.getClass) + assert(exceptionUtfBinary.getOrElse(new Exception()).getClass + == exceptionLcase.getOrElse(new Exception()).getClass) } } } + /** + * This test does following: + * 1) Extract relevant expressions + * 2) Run dataframe select on expressions with different inputs + * 3) Check if both expressions throw an exception + * 4) If no exception, check if the result is the same + * 5) Otherwise, check if exceptions are the same + */ test("SPARK-48280: Expression Walker for codeGen generation") { - /** - * This test does following: - * 1) Extract relevant expressions - * 2) Run dataframe select on expressions with different inputs - * 3) Check if both expressions throw an exception - * 4) If no exception, check if the result is the same - * 5) Otherwise, check if exceptions are the same - */ var (funInfos, toSkip) = extractRelevantExpressions() toSkip = toSkip ++ List( - // Known to be faulty - "to_xml", // These expressions are not called as functions "lead", + "lag", "nth_value", "session_window" ) for (f <- funInfos.filter(f => !toSkip.contains(f.getName))) { + println("checking - " + f.getName) val cl = Utils.classForName(f.getClassName) - val headConstructor = cl.getConstructors.head + val headConstructor = cl.getConstructors + .zip(cl.getConstructors.map(c => c.getParameters.length)).minBy(a => a._2)._1 val params = headConstructor.getParameters.map(p => p.getType) val args = generateData(params.toSeq, Utf8Binary) val expr = headConstructor.newInstance(args: _*).asInstanceOf[ExpectsInputTypes] @@ -556,7 +558,9 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi .getRows(1, 0) None } catch { - case e: Throwable => 
Some(e) + case e: Throwable => + println(e.getMessage) + Some(e) } assert(utf8BinaryResult.isDefined === utf8BinaryLcaseResult.isDefined) @@ -577,10 +581,7 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi utf8BinaryLcaseResult.getRows(1, 0).map(_.map(_.toLowerCase))) // scalastyle:on caselocale case _ => - // scalastyle:off caselocale - assert(utf8BinaryResult.getRows(1, 0)(1) === - utf8BinaryLcaseResult.getRows(1, 0)(1)) - // scalastyle:on caselocale + assert(utf8BinaryResult.getRows(1, 0)(1) === utf8BinaryLcaseResult.getRows(1, 0)(1)) } } else { @@ -590,15 +591,15 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi } } + /** + * This test does following: + * 1) Extract all expressions + * 2) Run example queries for different session level default collations + * 3) Check if both expressions throw an exception + * 4) If no exception, check if the result is the same + * 5) Otherwise, check if exceptions are the same + */ test("SPARK-48280: Expression Walker for SQL query examples") { - /** - * This test does following: - * 1) Extract all expressions - * 2) Run example queries for different session level default collations - * 3) Check if both expressions throw an exception - * 4) If no exception, check if the result is the same - * 5) Otherwise, check if exceptions are the same - */ val funInfos = spark.sessionState.functionRegistry.listFunction().map { funcId => spark.sessionState.catalog.lookupFunctionInfo(funcId) } @@ -641,8 +642,7 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi val resultUTF8Lcase = sql(m.substring(2)) assert(resultUTF8.collect() === resultUTF8Lcase.collect()) } - } - catch { + } catch { case e: SparkRuntimeException => assert(e.getErrorClass == "USER_RAISED_EXCEPTION") case other: Throwable => throw other } From 001f91da3de9be67288056d2c13d45e2d9eb1f6d Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Tue, 2 Jul 2024 08:59:43 +0200 
Subject: [PATCH 27/30] Improve code --- .../sql/CollationExpressionWalkerSuite.scala | 68 +++++++++---------- 1 file changed, 32 insertions(+), 36 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala index cc25bccbde808..dcb25ea09f762 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala @@ -120,6 +120,7 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi case _: DoubleType => Literal(5.0) case IntegerType | NumericType | IntegralType => Literal(5) case LongType => Literal(5L) + case NullType => Literal(null) case _: StringType | AnyDataType | _: AbstractStringType => collationType match { case Utf8Binary => @@ -197,6 +198,7 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi case Utf8Binary => "'dummy string' COLLATE UTF8_BINARY" case Utf8BinaryLcase => "'DuMmY sTrInG' COLLATE UTF8_LCASE" } + case NullType => "null" case VariantType => s"parse_json('{}')" case TypeCollection(typeCollection) => val strTypes = typeCollection.filter(hasStringType) @@ -524,16 +526,9 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi * 5) Otherwise, check if exceptions are the same */ test("SPARK-48280: Expression Walker for codeGen generation") { - var (funInfos, toSkip) = extractRelevantExpressions() - toSkip = toSkip ++ List( - // These expressions are not called as functions - "lead", - "lag", - "nth_value", - "session_window" - ) + val (funInfos, toSkip) = extractRelevantExpressions() + for (f <- funInfos.filter(f => !toSkip.contains(f.getName))) { - println("checking - " + f.getName) val cl = Utils.classForName(f.getClassName) val headConstructor = cl.getConstructors .zip(cl.getConstructors.map(c => 
c.getParameters.length)).minBy(a => a._2)._1 @@ -543,49 +538,47 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi withTable("tbl", "tbl_lcase") { - val utf8_df = generateTableData(expr.inputTypes, Utf8Binary) - val utf8_lcase_df = generateTableData(expr.inputTypes, Utf8BinaryLcase) + val utf8_df = generateTableData(expr.inputTypes.take(1), Utf8Binary) + val utf8_lcase_df = generateTableData(expr.inputTypes.take(1), Utf8BinaryLcase) val utf8BinaryResult = try { - utf8_df.selectExpr(transformExpressionToString(expr, Utf8Binary)) - .getRows(1, 0) - None + val df = utf8_df.selectExpr(transformExpressionToString(expr, Utf8Binary)) + df.getRows(1, 0) + scala.util.Right(df) } catch { - case e: Throwable => Some(e) + case e: Throwable => scala.util.Left(e) } val utf8BinaryLcaseResult = try { - utf8_lcase_df.selectExpr(transformExpressionToString(expr, Utf8BinaryLcase)) - .getRows(1, 0) - None + val df = utf8_lcase_df.selectExpr(transformExpressionToString(expr, Utf8BinaryLcase)) + df.getRows(1, 0) + scala.util.Right(df) } catch { - case e: Throwable => - println(e.getMessage) - Some(e) + case e: Throwable => scala.util.Left(e) } - assert(utf8BinaryResult.isDefined === utf8BinaryLcaseResult.isDefined) + assert(utf8BinaryResult.isLeft === utf8BinaryLcaseResult.isLeft) - if (utf8BinaryResult.isEmpty) { - val utf8BinaryResult = - utf8_df.selectExpr(transformExpressionToString(expr, Utf8Binary)) - val utf8BinaryLcaseResult = - utf8_lcase_df.selectExpr(transformExpressionToString(expr, Utf8BinaryLcase)) + if (utf8BinaryResult.isRight) { + val utf8BinaryResultChecked = utf8BinaryResult.getOrElse(null) + val utf8BinaryLcaseResultChecked = utf8BinaryLcaseResult.getOrElse(null) - val dt = utf8BinaryResult.schema.fields.head.dataType + val dt = utf8BinaryResultChecked.schema.fields.head.dataType dt match { - case st if utf8BinaryResult != null && utf8BinaryLcaseResult != null && + case st if utf8BinaryResultChecked != null && 
utf8BinaryLcaseResultChecked != null && hasStringType(st) => // scalastyle:off caselocale - assert(utf8BinaryResult.getRows(1, 0).map(_.map(_.toLowerCase)) === - utf8BinaryLcaseResult.getRows(1, 0).map(_.map(_.toLowerCase))) + assert(utf8BinaryResultChecked.getRows(1, 0).map(_.map(_.toLowerCase)) === + utf8BinaryLcaseResultChecked.getRows(1, 0).map(_.map(_.toLowerCase))) // scalastyle:on caselocale case _ => - assert(utf8BinaryResult.getRows(1, 0)(1) === utf8BinaryLcaseResult.getRows(1, 0)(1)) + assert(utf8BinaryResultChecked.getRows(1, 0)(1) === + utf8BinaryLcaseResultChecked.getRows(1, 0)(1)) } } else { - assert(utf8BinaryResult.get.getClass == utf8BinaryResult.get.getClass) + assert(utf8BinaryResult.getOrElse(new Exception()).getClass + == utf8BinaryResult.getOrElse(new Exception()).getClass) } } } @@ -622,6 +615,9 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi "replace", "grouping", "grouping_id", + "reflect", + "try_reflect", + "java_method", // need to skip as these are random functions "rand", "random", @@ -629,12 +625,12 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi "uuid", "shuffle", // other functions which are not yet supported - "reflect", - "try_reflect", - "java_method" + "to_avro", + "from_avro" ) for (funInfo <- funInfos.filter(f => !toSkip.contains(f.getName))) { + println("checking - " + funInfo.getName) for (m <- "> .*;".r.findAllIn(funInfo.getExamples)) { try { val resultUTF8 = sql(m.substring(2)) From 9082b5e87576b913ae2c34e21a56fc8a781316d5 Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Tue, 2 Jul 2024 10:08:01 +0200 Subject: [PATCH 28/30] remove printing --- .../org/apache/spark/sql/CollationExpressionWalkerSuite.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala index 
dcb25ea09f762..be078d6aa9411 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala @@ -630,7 +630,6 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi ) for (funInfo <- funInfos.filter(f => !toSkip.contains(f.getName))) { - println("checking - " + funInfo.getName) for (m <- "> .*;".r.findAllIn(funInfo.getExamples)) { try { val resultUTF8 = sql(m.substring(2)) From 10a358bf4582984d3003ec5f1f5da1bb0b98d0d7 Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Tue, 2 Jul 2024 13:01:56 +0200 Subject: [PATCH 29/30] Run CIs with printing --- .../org/apache/spark/sql/CollationExpressionWalkerSuite.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala index be078d6aa9411..a0d9079348ba8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala @@ -529,6 +529,7 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi val (funInfos, toSkip) = extractRelevantExpressions() for (f <- funInfos.filter(f => !toSkip.contains(f.getName))) { + println("checking - " + f.getName) val cl = Utils.classForName(f.getClassName) val headConstructor = cl.getConstructors .zip(cl.getConstructors.map(c => c.getParameters.length)).minBy(a => a._2)._1 From 57d11c886c2f297d9e8eb97c393c91d185caf7e3 Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Wed, 3 Jul 2024 11:57:43 +0200 Subject: [PATCH 30/30] Remove printing --- .../apache/spark/sql/CollationExpressionWalkerSuite.scala | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala index a0d9079348ba8..d582167478da9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala @@ -529,7 +529,6 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi val (funInfos, toSkip) = extractRelevantExpressions() for (f <- funInfos.filter(f => !toSkip.contains(f.getName))) { - println("checking - " + f.getName) val cl = Utils.classForName(f.getClassName) val headConstructor = cl.getConstructors .zip(cl.getConstructors.map(c => c.getParameters.length)).minBy(a => a._2)._1 @@ -539,8 +538,8 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi withTable("tbl", "tbl_lcase") { - val utf8_df = generateTableData(expr.inputTypes.take(1), Utf8Binary) - val utf8_lcase_df = generateTableData(expr.inputTypes.take(1), Utf8BinaryLcase) + val utf8_df = generateTableData(expr.inputTypes.take(2), Utf8Binary) + val utf8_lcase_df = generateTableData(expr.inputTypes.take(2), Utf8BinaryLcase) val utf8BinaryResult = try { val df = utf8_df.selectExpr(transformExpressionToString(expr, Utf8Binary))