From 3e69b40300af24482e690fa0ca47e1141cea91e2 Mon Sep 17 00:00:00 2001 From: Jovan Pavlovic Date: Sat, 5 Oct 2024 09:16:40 +0900 Subject: [PATCH] [SPARK-49683][SQL] Block trim collation ### What changes were proposed in this pull request? Trim collation is currently in the implementation phase. This change blocks all paths from using it; afterwards, as trim collation gets enabled for different expressions, those expressions will be gradually whitelisted. ### Why are the changes needed? Trim collation is currently in the implementation phase. This change blocks all paths from using it; afterwards, as trim collation gets enabled for different expressions, those expressions will be gradually whitelisted. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? No additional tests, just added a field that's not used. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48336 from jovanpavl-db/block-collation-trim. Lead-authored-by: Jovan Pavlovic Co-authored-by: Hyukjin Kwon Signed-off-by: Hyukjin Kwon --- .../internal/types/AbstractStringType.scala | 56 +++++++++++++++---- .../apache/spark/sql/types/StringType.scala | 3 + .../catalyst/expressions/CollationKey.scala | 3 +- .../aggregate/datasketchesAggregates.scala | 6 +- .../expressions/collationExpressions.scala | 6 +- .../sql/CollationSQLExpressionsSuite.scala | 52 +++++++++-------- 6 files changed, 89 insertions(+), 37 deletions(-) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractStringType.scala b/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractStringType.scala index 6feb662632763..c3643f4bd15be 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractStringType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractStringType.scala @@ -21,43 +21,79 @@ import org.apache.spark.sql.internal.SqlApiConf import org.apache.spark.sql.types.{AbstractDataType, DataType, StringType} /** - * AbstractStringType 
is an abstract class for StringType with collation support. + * AbstractStringType is an abstract class for StringType with collation support. As every type of + * collation can support trim specifier this class is parametrized with it. */ -abstract class AbstractStringType extends AbstractDataType { +abstract class AbstractStringType(private[sql] val supportsTrimCollation: Boolean = false) + extends AbstractDataType { override private[sql] def defaultConcreteType: DataType = SqlApiConf.get.defaultStringType override private[sql] def simpleString: String = "string" + private[sql] def canUseTrimCollation(other: DataType): Boolean = + supportsTrimCollation || !other.asInstanceOf[StringType].usesTrimCollation } /** * Use StringTypeBinary for expressions supporting only binary collation. */ -case object StringTypeBinary extends AbstractStringType { +case class StringTypeBinary(override val supportsTrimCollation: Boolean = false) + extends AbstractStringType(supportsTrimCollation) { override private[sql] def acceptsType(other: DataType): Boolean = - other.isInstanceOf[StringType] && other.asInstanceOf[StringType].supportsBinaryEquality + other.isInstanceOf[StringType] && other.asInstanceOf[StringType].supportsBinaryEquality && + canUseTrimCollation(other) +} + +object StringTypeBinary extends StringTypeBinary(false) { + def apply(supportsTrimCollation: Boolean): StringTypeBinary = { + new StringTypeBinary(supportsTrimCollation) + } } /** * Use StringTypeBinaryLcase for expressions supporting only binary and lowercase collation. 
*/ -case object StringTypeBinaryLcase extends AbstractStringType { +case class StringTypeBinaryLcase(override val supportsTrimCollation: Boolean = false) + extends AbstractStringType(supportsTrimCollation) { override private[sql] def acceptsType(other: DataType): Boolean = other.isInstanceOf[StringType] && (other.asInstanceOf[StringType].supportsBinaryEquality || - other.asInstanceOf[StringType].isUTF8LcaseCollation) + other.asInstanceOf[StringType].isUTF8LcaseCollation) && canUseTrimCollation(other) +} + +object StringTypeBinaryLcase extends StringTypeBinaryLcase(false) { + def apply(supportsTrimCollation: Boolean): StringTypeBinaryLcase = { + new StringTypeBinaryLcase(supportsTrimCollation) + } } /** * Use StringTypeWithCaseAccentSensitivity for expressions supporting all collation types (binary * and ICU) but limited to using case and accent sensitivity specifiers. */ -case object StringTypeWithCaseAccentSensitivity extends AbstractStringType { - override private[sql] def acceptsType(other: DataType): Boolean = other.isInstanceOf[StringType] +case class StringTypeWithCaseAccentSensitivity( + override val supportsTrimCollation: Boolean = false) + extends AbstractStringType(supportsTrimCollation) { + override private[sql] def acceptsType(other: DataType): Boolean = + other.isInstanceOf[StringType] && canUseTrimCollation(other) +} + +object StringTypeWithCaseAccentSensitivity extends StringTypeWithCaseAccentSensitivity(false) { + def apply(supportsTrimCollation: Boolean): StringTypeWithCaseAccentSensitivity = { + new StringTypeWithCaseAccentSensitivity(supportsTrimCollation) + } } /** * Use StringTypeNonCSAICollation for expressions supporting all possible collation types except * CS_AI collation types. 
*/ -case object StringTypeNonCSAICollation extends AbstractStringType { +case class StringTypeNonCSAICollation(override val supportsTrimCollation: Boolean = false) + extends AbstractStringType(supportsTrimCollation) { override private[sql] def acceptsType(other: DataType): Boolean = - other.isInstanceOf[StringType] && other.asInstanceOf[StringType].isNonCSAI + other.isInstanceOf[StringType] && other.asInstanceOf[StringType].isNonCSAI && + canUseTrimCollation(other) +} + +object StringTypeNonCSAICollation extends StringTypeNonCSAICollation(false) { + def apply(supportsTrimCollation: Boolean): StringTypeNonCSAICollation = { + new StringTypeNonCSAICollation(supportsTrimCollation) + } } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala index c2dd6cec7ba74..29d48e3d1f47f 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala @@ -47,6 +47,9 @@ class StringType private (val collationId: Int) extends AtomicType with Serializ private[sql] def isNonCSAI: Boolean = !CollationFactory.isCaseSensitiveAndAccentInsensitive(collationId) + private[sql] def usesTrimCollation: Boolean = + CollationFactory.usesTrimCollation(collationId) + private[sql] def isUTF8BinaryCollation: Boolean = collationId == CollationFactory.UTF8_BINARY_COLLATION_ID diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CollationKey.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CollationKey.scala index 28ec8482e5cdd..81bafda54135f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CollationKey.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CollationKey.scala @@ -24,7 +24,8 @@ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String case class CollationKey(expr: 
Expression) extends UnaryExpression with ExpectsInputTypes { - override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCaseAccentSensitivity) + override def inputTypes: Seq[AbstractDataType] = + Seq(StringTypeWithCaseAccentSensitivity(/* supportsTrimCollation = */ true)) override def dataType: DataType = BinaryType final lazy val collationId: Int = expr.dataType match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/datasketchesAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/datasketchesAggregates.scala index 78bd02d5703cd..a6448051a3996 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/datasketchesAggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/datasketchesAggregates.scala @@ -106,7 +106,11 @@ case class HllSketchAgg( override def inputTypes: Seq[AbstractDataType] = Seq( - TypeCollection(IntegerType, LongType, StringTypeWithCaseAccentSensitivity, BinaryType), + TypeCollection( + IntegerType, + LongType, + StringTypeWithCaseAccentSensitivity(/* supportsTrimCollation = */ true), + BinaryType), IntegerType) override def dataType: DataType = BinaryType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collationExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collationExpressions.scala index b67e66323bbbd..effcdc4b038e5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collationExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collationExpressions.scala @@ -77,7 +77,8 @@ case class Collate(child: Expression, collationName: String) extends UnaryExpression with ExpectsInputTypes { private val collationId = CollationFactory.collationNameToId(collationName) override def dataType: DataType = StringType(collationId) - 
override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCaseAccentSensitivity) + override def inputTypes: Seq[AbstractDataType] = + Seq(StringTypeWithCaseAccentSensitivity(/* supportsTrimCollation = */ true)) override protected def withNewChildInternal( newChild: Expression): Expression = copy(newChild) @@ -115,5 +116,6 @@ case class Collation(child: Expression) val collationName = CollationFactory.fetchCollation(collationId).collationName Literal.create(collationName, SQLConf.get.defaultStringType) } - override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCaseAccentSensitivity) + override def inputTypes: Seq[AbstractDataType] = + Seq(StringTypeWithCaseAccentSensitivity(/* supportsTrimCollation = */ true)) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala index 851160d2fbb94..4c3cd93873bd4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala @@ -982,7 +982,11 @@ class CollationSQLExpressionsSuite StringToMapTestCase("1/AX2/BX3/C", "x", "/", "UNICODE_CI", Map("1" -> "A", "2" -> "B", "3" -> "C")) ) - val unsupportedTestCase = StringToMapTestCase("a:1,b:2,c:3", "?", "?", "UNICODE_AI", null) + val unsupportedTestCases = Seq( + StringToMapTestCase("a:1,b:2,c:3", "?", "?", "UNICODE_AI", null), + StringToMapTestCase("a:1,b:2,c:3", "?", "?", "UNICODE_RTRIM", null), + StringToMapTestCase("a:1,b:2,c:3", "?", "?", "UTF8_BINARY_RTRIM", null), + StringToMapTestCase("a:1,b:2,c:3", "?", "?", "UTF8_LCASE_RTRIM", null)) testCases.foreach(t => { // Unit test. val text = Literal.create(t.text, StringType(t.collation)) @@ -998,28 +1002,30 @@ class CollationSQLExpressionsSuite } }) // Test unsupported collation. 
- withSQLConf(SQLConf.DEFAULT_COLLATION.key -> unsupportedTestCase.collation) { - val query = - s"select str_to_map('${unsupportedTestCase.text}', '${unsupportedTestCase.pairDelim}', " + - s"'${unsupportedTestCase.keyValueDelim}')" - checkError( - exception = intercept[AnalysisException] { - sql(query).collect() - }, - condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", - sqlState = Some("42K09"), - parameters = Map( - "sqlExpr" -> ("\"str_to_map('a:1,b:2,c:3' collate UNICODE_AI, " + - "'?' collate UNICODE_AI, '?' collate UNICODE_AI)\""), - "paramIndex" -> "first", - "inputSql" -> "\"'a:1,b:2,c:3' collate UNICODE_AI\"", - "inputType" -> "\"STRING COLLATE UNICODE_AI\"", - "requiredType" -> "\"STRING\""), - context = ExpectedContext( - fragment = "str_to_map('a:1,b:2,c:3', '?', '?')", - start = 7, - stop = 41)) - } + unsupportedTestCases.foreach(t => { + withSQLConf(SQLConf.DEFAULT_COLLATION.key -> t.collation) { + val query = + s"select str_to_map('${t.text}', '${t.pairDelim}', " + + s"'${t.keyValueDelim}')" + checkError( + exception = intercept[AnalysisException] { + sql(query).collect() + }, + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + sqlState = Some("42K09"), + parameters = Map( + "sqlExpr" -> ("\"str_to_map('a:1,b:2,c:3' collate " + s"${t.collation}, " + + "'?' collate " + s"${t.collation}, '?' collate ${t.collation})" + "\""), + "paramIndex" -> "first", + "inputSql" -> ("\"'a:1,b:2,c:3' collate " + s"${t.collation}" + "\""), + "inputType" -> ("\"STRING COLLATE " + s"${t.collation}" + "\""), + "requiredType" -> "\"STRING\""), + context = ExpectedContext( + fragment = "str_to_map('a:1,b:2,c:3', '?', '?')", + start = 7, + stop = 41)) + } + }) } test("Support RaiseError misc expression with collation") {