Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-49683][SQL] Block trim collation #48336

Closed
wants to merge 11 commits into from
Original file line number Diff line number Diff line change
Expand Up @@ -21,43 +21,79 @@ import org.apache.spark.sql.internal.SqlApiConf
import org.apache.spark.sql.types.{AbstractDataType, DataType, StringType}

/**
* AbstractStringType is an abstract class for StringType with collation support.
* AbstractStringType is an abstract class for StringType with collation support. As every type of
* collation can support trim specifier this class is parametrized with it.
*/
abstract class AbstractStringType extends AbstractDataType {
abstract class AbstractStringType(private[sql] val supportsTrimCollation: Boolean = false)
extends AbstractDataType {
override private[sql] def defaultConcreteType: DataType = SqlApiConf.get.defaultStringType
override private[sql] def simpleString: String = "string"
private[sql] def canUseTrimCollation(other: DataType): Boolean =
supportsTrimCollation || !other.asInstanceOf[StringType].usesTrimCollation
}

/**
* Use StringTypeBinary for expressions supporting only binary collation.
*/
case object StringTypeBinary extends AbstractStringType {
case class StringTypeBinary(override val supportsTrimCollation: Boolean = false)
extends AbstractStringType(supportsTrimCollation) {
override private[sql] def acceptsType(other: DataType): Boolean =
other.isInstanceOf[StringType] && other.asInstanceOf[StringType].supportsBinaryEquality
other.isInstanceOf[StringType] && other.asInstanceOf[StringType].supportsBinaryEquality &&
canUseTrimCollation(other)
}

object StringTypeBinary extends StringTypeBinary(false) {
def apply(supportsTrimCollation: Boolean): StringTypeBinary = {
new StringTypeBinary(supportsTrimCollation)
}
jovanpavl-db marked this conversation as resolved.
Show resolved Hide resolved
}

/**
* Use StringTypeBinaryLcase for expressions supporting only binary and lowercase collation.
*/
case object StringTypeBinaryLcase extends AbstractStringType {
case class StringTypeBinaryLcase(override val supportsTrimCollation: Boolean = false)
extends AbstractStringType(supportsTrimCollation) {
override private[sql] def acceptsType(other: DataType): Boolean =
other.isInstanceOf[StringType] && (other.asInstanceOf[StringType].supportsBinaryEquality ||
other.asInstanceOf[StringType].isUTF8LcaseCollation)
other.asInstanceOf[StringType].isUTF8LcaseCollation) && canUseTrimCollation(other)
}

object StringTypeBinaryLcase extends StringTypeBinaryLcase(false) {
def apply(supportsTrimCollation: Boolean): StringTypeBinaryLcase = {
new StringTypeBinaryLcase(supportsTrimCollation)
}
}

/**
* Use StringTypeWithCaseAccentSensitivity for expressions supporting all collation types (binary
* and ICU) but limited to using case and accent sensitivity specifiers.
*/
case object StringTypeWithCaseAccentSensitivity extends AbstractStringType {
override private[sql] def acceptsType(other: DataType): Boolean = other.isInstanceOf[StringType]
case class StringTypeWithCaseAccentSensitivity(
override val supportsTrimCollation: Boolean = false)
extends AbstractStringType(supportsTrimCollation) {
override private[sql] def acceptsType(other: DataType): Boolean =
other.isInstanceOf[StringType] && canUseTrimCollation(other)
}

object StringTypeWithCaseAccentSensitivity extends StringTypeWithCaseAccentSensitivity(false) {
def apply(supportsTrimCollation: Boolean): StringTypeWithCaseAccentSensitivity = {
new StringTypeWithCaseAccentSensitivity(supportsTrimCollation)
}
}

/**
* Use StringTypeNonCSAICollation for expressions supporting all possible collation types except
* CS_AI collation types.
*/
case object StringTypeNonCSAICollation extends AbstractStringType {
case class StringTypeNonCSAICollation(override val supportsTrimCollation: Boolean = false)
extends AbstractStringType(supportsTrimCollation) {
override private[sql] def acceptsType(other: DataType): Boolean =
other.isInstanceOf[StringType] && other.asInstanceOf[StringType].isNonCSAI
other.isInstanceOf[StringType] && other.asInstanceOf[StringType].isNonCSAI &&
canUseTrimCollation(other)
}

object StringTypeNonCSAICollation extends StringTypeNonCSAICollation(false) {
def apply(supportsTrimCollation: Boolean): StringTypeNonCSAICollation = {
new StringTypeNonCSAICollation(supportsTrimCollation)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ class StringType private (val collationId: Int) extends AtomicType with Serializ
private[sql] def isNonCSAI: Boolean =
!CollationFactory.isCaseSensitiveAndAccentInsensitive(collationId)

private[sql] def usesTrimCollation: Boolean =
CollationFactory.usesTrimCollation(collationId)

private[sql] def isUTF8BinaryCollation: Boolean =
collationId == CollationFactory.UTF8_BINARY_COLLATION_ID

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

case class CollationKey(expr: Expression) extends UnaryExpression with ExpectsInputTypes {
override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCaseAccentSensitivity)
override def inputTypes: Seq[AbstractDataType] =
Seq(StringTypeWithCaseAccentSensitivity(/* supportsTrimCollation = */ true))
override def dataType: DataType = BinaryType

final lazy val collationId: Int = expr.dataType match {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,11 @@ case class HllSketchAgg(

override def inputTypes: Seq[AbstractDataType] =
Seq(
TypeCollection(IntegerType, LongType, StringTypeWithCaseAccentSensitivity, BinaryType),
TypeCollection(
IntegerType,
LongType,
StringTypeWithCaseAccentSensitivity(/* supportsTrimCollation = */ true),
BinaryType),
IntegerType)

override def dataType: DataType = BinaryType
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,8 @@ case class Collate(child: Expression, collationName: String)
extends UnaryExpression with ExpectsInputTypes {
private val collationId = CollationFactory.collationNameToId(collationName)
override def dataType: DataType = StringType(collationId)
override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCaseAccentSensitivity)
override def inputTypes: Seq[AbstractDataType] =
Seq(StringTypeWithCaseAccentSensitivity(/* supportsTrimCollation = */ true))

override protected def withNewChildInternal(
newChild: Expression): Expression = copy(newChild)
Expand Down Expand Up @@ -115,5 +116,6 @@ case class Collation(child: Expression)
val collationName = CollationFactory.fetchCollation(collationId).collationName
Literal.create(collationName, SQLConf.get.defaultStringType)
}
override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCaseAccentSensitivity)
override def inputTypes: Seq[AbstractDataType] =
Seq(StringTypeWithCaseAccentSensitivity(/* supportsTrimCollation = */ true))
}
Original file line number Diff line number Diff line change
Expand Up @@ -982,7 +982,11 @@ class CollationSQLExpressionsSuite
StringToMapTestCase("1/AX2/BX3/C", "x", "/", "UNICODE_CI",
Map("1" -> "A", "2" -> "B", "3" -> "C"))
)
val unsupportedTestCase = StringToMapTestCase("a:1,b:2,c:3", "?", "?", "UNICODE_AI", null)
val unsupportedTestCases = Seq(
StringToMapTestCase("a:1,b:2,c:3", "?", "?", "UNICODE_AI", null),
StringToMapTestCase("a:1,b:2,c:3", "?", "?", "UNICODE_RTRIM", null),
StringToMapTestCase("a:1,b:2,c:3", "?", "?", "UTF8_BINARY_RTRIM", null),
StringToMapTestCase("a:1,b:2,c:3", "?", "?", "UTF8_LCASE_RTRIM", null))
testCases.foreach(t => {
// Unit test.
val text = Literal.create(t.text, StringType(t.collation))
Expand All @@ -998,28 +1002,30 @@ class CollationSQLExpressionsSuite
}
})
// Test unsupported collation.
withSQLConf(SQLConf.DEFAULT_COLLATION.key -> unsupportedTestCase.collation) {
val query =
s"select str_to_map('${unsupportedTestCase.text}', '${unsupportedTestCase.pairDelim}', " +
s"'${unsupportedTestCase.keyValueDelim}')"
checkError(
exception = intercept[AnalysisException] {
sql(query).collect()
},
condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
sqlState = Some("42K09"),
parameters = Map(
"sqlExpr" -> ("\"str_to_map('a:1,b:2,c:3' collate UNICODE_AI, " +
"'?' collate UNICODE_AI, '?' collate UNICODE_AI)\""),
"paramIndex" -> "first",
"inputSql" -> "\"'a:1,b:2,c:3' collate UNICODE_AI\"",
"inputType" -> "\"STRING COLLATE UNICODE_AI\"",
"requiredType" -> "\"STRING\""),
context = ExpectedContext(
fragment = "str_to_map('a:1,b:2,c:3', '?', '?')",
start = 7,
stop = 41))
}
unsupportedTestCases.foreach(t => {
withSQLConf(SQLConf.DEFAULT_COLLATION.key -> t.collation) {
val query =
s"select str_to_map('${t.text}', '${t.pairDelim}', " +
s"'${t.keyValueDelim}')"
checkError(
exception = intercept[AnalysisException] {
sql(query).collect()
},
condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
sqlState = Some("42K09"),
parameters = Map(
"sqlExpr" -> ("\"str_to_map('a:1,b:2,c:3' collate " + s"${t.collation}, " +
"'?' collate " + s"${t.collation}, '?' collate ${t.collation})" + "\""),
"paramIndex" -> "first",
"inputSql" -> ("\"'a:1,b:2,c:3' collate " + s"${t.collation}" + "\""),
"inputType" -> ("\"STRING COLLATE " + s"${t.collation}" + "\""),
"requiredType" -> "\"STRING\""),
context = ExpectedContext(
fragment = "str_to_map('a:1,b:2,c:3', '?', '?')",
start = 7,
stop = 41))
}
})
}

test("Support RaiseError misc expression with collation") {
Expand Down