From a1e67330478cf5eb19b14e2b524321a02eeaded0 Mon Sep 17 00:00:00 2001 From: kasiafi <30203062+kasiafi@users.noreply.github.com> Date: Fri, 3 Dec 2021 16:17:06 +0100 Subject: [PATCH] Fix cast from decimal to varchar Fail if the resulting string does not fit in the bounded length of the varchar type. Applies both to short and long decimal types. --- .../main/java/io/trino/type/DecimalCasts.java | 47 +++++++++++++++++-- .../trino/sql/TestExpressionInterpreter.java | 36 ++++++++++---- .../rule/TestSimplifyExpressions.java | 28 ++++------- .../java/io/trino/type/TestDecimalCasts.java | 4 +- 4 files changed, 78 insertions(+), 37 deletions(-) diff --git a/core/trino-main/src/main/java/io/trino/type/DecimalCasts.java b/core/trino-main/src/main/java/io/trino/type/DecimalCasts.java index 9f92a2dad48e..71ce832ca28d 100644 --- a/core/trino-main/src/main/java/io/trino/type/DecimalCasts.java +++ b/core/trino-main/src/main/java/io/trino/type/DecimalCasts.java @@ -31,6 +31,7 @@ import io.trino.spi.type.Decimals; import io.trino.spi.type.StandardTypes; import io.trino.spi.type.TypeSignature; +import io.trino.spi.type.VarcharType; import io.trino.util.JsonCastException; import java.io.IOException; @@ -62,6 +63,7 @@ import static io.trino.spi.type.UnscaledDecimal128Arithmetic.rescale; import static io.trino.spi.type.UnscaledDecimal128Arithmetic.unscaledDecimal; import static io.trino.spi.type.UnscaledDecimal128Arithmetic.unscaledDecimalToUnscaledLong; +import static io.trino.spi.type.VarcharType.UNBOUNDED_LENGTH; import static io.trino.type.JsonType.JSON; import static io.trino.util.Failures.checkCondition; import static io.trino.util.JsonUtil.createJsonGenerator; @@ -91,7 +93,6 @@ public final class DecimalCasts public static final SqlScalarFunction DOUBLE_TO_DECIMAL_CAST = castFunctionToDecimalFrom(DOUBLE.getTypeSignature(), "doubleToShortDecimal", "doubleToLongDecimal"); public static final SqlScalarFunction DECIMAL_TO_REAL_CAST = castFunctionFromDecimalTo(REAL.getTypeSignature(), "shortDecimalToReal", "longDecimalToReal"); public static final SqlScalarFunction REAL_TO_DECIMAL_CAST = castFunctionToDecimalFrom(REAL.getTypeSignature(), "realToShortDecimal", "realToLongDecimal"); - public static final SqlScalarFunction DECIMAL_TO_VARCHAR_CAST = castFunctionFromDecimalTo(new TypeSignature("varchar", typeVariable("x")), "shortDecimalToVarchar", "longDecimalToVarchar"); public static final SqlScalarFunction VARCHAR_TO_DECIMAL_CAST = castFunctionToDecimalFrom(new TypeSignature("varchar", typeVariable("x")), "varcharToShortDecimal", "varcharToLongDecimal"); public static final SqlScalarFunction DECIMAL_TO_JSON_CAST = castFunctionFromDecimalTo(JSON.getTypeSignature(), "shortDecimalToJson", "longDecimalToJson"); public static final SqlScalarFunction JSON_TO_DECIMAL_CAST = castFunctionToDecimalFromBuilder(JSON.getTypeSignature(), true, "jsonToShortDecimal", "jsonToLongDecimal"); @@ -157,6 +158,30 @@ private static SqlScalarFunction castFunctionToDecimalFromBuilder(TypeSignature }))).build(); } + public static final SqlScalarFunction DECIMAL_TO_VARCHAR_CAST = new PolymorphicScalarFunctionBuilder(DecimalCasts.class) + .signature(Signature.builder() + .operatorType(CAST) + .argumentTypes(new TypeSignature("decimal", typeVariable("precision"), typeVariable("scale"))) + .returnType(new TypeSignature("varchar", typeVariable("x"))) + .build()) + .deterministic(true) + .choice(choice -> choice + .implementation(methodsGroup -> methodsGroup + .methods("shortDecimalToVarchar", "longDecimalToVarchar") + .withExtraParameters((context) -> { + long scale = context.getLiteral("scale"); + VarcharType resultType = (VarcharType) context.getReturnType(); + long length; + if (resultType.isUnbounded()) { + length = UNBOUNDED_LENGTH; + } + else { + length = resultType.getBoundedLength(); + } + return ImmutableList.of(scale, length); + }))) + .build(); + private DecimalCasts() {} @UsedByGeneratedCode @@ -457,15 +482,27 @@ public static Slice realToLongDecimal(long value, long precision, long scale, Bi } @UsedByGeneratedCode - public static Slice shortDecimalToVarchar(long decimal, long precision, long scale, long tenToScale) + public static Slice shortDecimalToVarchar(long decimal, long scale, long varcharLength) { - return utf8Slice(Decimals.toString(decimal, DecimalConversions.intScale(scale))); + String stringValue = Decimals.toString(decimal, DecimalConversions.intScale(scale)); + // String is all-ASCII, so String.length() here returns actual code points count + if (stringValue.length() <= varcharLength) { + return utf8Slice(stringValue); + } + + throw new TrinoException(INVALID_CAST_ARGUMENT, format("Value %s cannot be represented as varchar(%s)", stringValue, varcharLength)); } @UsedByGeneratedCode - public static Slice longDecimalToVarchar(Slice decimal, long precision, long scale, BigInteger tenToScale) + public static Slice longDecimalToVarchar(Slice decimal, long scale, long varcharLength) { - return utf8Slice(Decimals.toString(decimal, DecimalConversions.intScale(scale))); + String stringValue = Decimals.toString(decimal, DecimalConversions.intScale(scale)); + // String is all-ASCII, so String.length() here returns actual code points count + if (stringValue.length() <= varcharLength) { + return utf8Slice(stringValue); + } + + throw new TrinoException(INVALID_CAST_ARGUMENT, format("Value %s cannot be represented as varchar(%s)", stringValue, varcharLength)); } @UsedByGeneratedCode diff --git a/core/trino-main/src/test/java/io/trino/sql/TestExpressionInterpreter.java b/core/trino-main/src/test/java/io/trino/sql/TestExpressionInterpreter.java index 8a92c21111db..99e28e96fc91 100644 --- a/core/trino-main/src/test/java/io/trino/sql/TestExpressionInterpreter.java +++ b/core/trino-main/src/test/java/io/trino/sql/TestExpressionInterpreter.java @@ -687,23 +687,39 @@ public void testCastDecimalToBoundedVarchar() assertEvaluatedEquals("CAST(DECIMAL '12.4' AS varchar(4))", "'12.4'"); assertEvaluatedEquals("CAST(DECIMAL '12.4' AS varchar(50))", "'12.4'"); - // short decimal: incorrect behavior: the result value does not fit in the type - assertEvaluatedEquals("CAST(DECIMAL '12.4' AS varchar(3))", "'12.4'"); - assertEvaluatedEquals("CAST(DECIMAL '-12.4' AS varchar(3))", "'-12.4'"); + assertTrinoExceptionThrownBy(() -> evaluate("CAST(DECIMAL '12.4' AS varchar(3))")) + .hasErrorCode(INVALID_CAST_ARGUMENT) + .hasMessage("Value 12.4 cannot be represented as varchar(3)"); + assertTrinoExceptionThrownBy(() -> evaluate("CAST(DECIMAL '-12.4' AS varchar(3))")) + .hasErrorCode(INVALID_CAST_ARGUMENT) + .hasMessage("Value -12.4 cannot be represented as varchar(3)"); + // the trailing 0 does not fit in the type - assertEvaluatedEquals("CAST(DECIMAL '12.40' AS varchar(4))", "'12.40'"); - assertEvaluatedEquals("CAST(DECIMAL '-12.40' AS varchar(5))", "'-12.40'"); + assertTrinoExceptionThrownBy(() -> evaluate("CAST(DECIMAL '12.40' AS varchar(4))")) + .hasErrorCode(INVALID_CAST_ARGUMENT) + .hasMessage("Value 12.40 cannot be represented as varchar(4)"); + assertTrinoExceptionThrownBy(() -> evaluate("CAST(DECIMAL '-12.40' AS varchar(5))")) + .hasErrorCode(INVALID_CAST_ARGUMENT) + .hasMessage("Value -12.40 cannot be represented as varchar(5)"); // long decimal assertEvaluatedEquals("CAST(DECIMAL '100000000000000000.1' AS varchar(20))", "'100000000000000000.1'"); assertEvaluatedEquals("CAST(DECIMAL '100000000000000000.1' AS varchar(50))", "'100000000000000000.1'"); - // long decimal: incorrect behavior: the result value does not fit in the type - assertEvaluatedEquals("CAST(DECIMAL '100000000000000000.1' AS varchar(3))", "'100000000000000000.1'"); - assertEvaluatedEquals("CAST(DECIMAL '-100000000000000000.1' AS varchar(3))", "'-100000000000000000.1'"); + assertTrinoExceptionThrownBy(() -> evaluate("CAST(DECIMAL '100000000000000000.1' AS varchar(3))")) + .hasErrorCode(INVALID_CAST_ARGUMENT) + .hasMessage("Value 100000000000000000.1 cannot be represented as varchar(3)"); + assertTrinoExceptionThrownBy(() -> evaluate("CAST(DECIMAL '-100000000000000000.1' AS varchar(3))")) + .hasErrorCode(INVALID_CAST_ARGUMENT) + .hasMessage("Value -100000000000000000.1 cannot be represented as varchar(3)"); + // the trailing 0 does not fit in the type - assertEvaluatedEquals("CAST(DECIMAL '100000000000000000.10' AS varchar(20))", "'100000000000000000.10'"); - assertEvaluatedEquals("CAST(DECIMAL '-100000000000000000.10' AS varchar(21))", "'-100000000000000000.10'"); + assertTrinoExceptionThrownBy(() -> evaluate("CAST(DECIMAL '100000000000000000.10' AS varchar(20))")) + .hasErrorCode(INVALID_CAST_ARGUMENT) + .hasMessage("Value 100000000000000000.10 cannot be represented as varchar(20)"); + assertTrinoExceptionThrownBy(() -> evaluate("CAST(DECIMAL '-100000000000000000.10' AS varchar(21))")) + .hasErrorCode(INVALID_CAST_ARGUMENT) + .hasMessage("Value -100000000000000000.10 cannot be represented as varchar(21)"); } @Test diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestSimplifyExpressions.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestSimplifyExpressions.java index b04384e172c0..1a317eeea7af 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestSimplifyExpressions.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestSimplifyExpressions.java @@ -187,16 +187,10 @@ public void testCastShortDecimalToBoundedVarchar() assertSimplifies("CAST(DECIMAL '12.4' AS varchar(4))", "'12.4'"); assertSimplifies("CAST(DECIMAL '-12.4' AS varchar(50))", "CAST('-12.4' AS varchar(50))"); - // the varchar type length is not enough to contain the number's representation: - // the cast operator returns a value that is too long for the expected type ('12.4' for varchar(3)) - // the value is then wrapped in another cast by the LiteralEncoder (CAST('12.4' AS varchar(3))), - // so eventually we get a truncated string '12.' - assertSimplifies("CAST(DECIMAL '12.4' AS varchar(3))", "CAST('12.4' AS varchar(3))"); - assertSimplifies("CAST(DECIMAL '-12.4' AS varchar(3))", "CAST('-12.4' AS varchar(3))"); - - // the cast operator returns a value that is too long for the expected type ('12.4' for varchar(3)) - // the value is nested in a comparison expression, so it is not truncated by the LiteralEncoder - assertSimplifies("CAST(DECIMAL '12.4' AS varchar(3)) = '12.4'", "true"); + // cast from short decimal to varchar fails, so the expression is not modified + assertSimplifies("CAST(DECIMAL '12.4' AS varchar(3))", "CAST(DECIMAL '12.4' AS varchar(3))"); + assertSimplifies("CAST(DECIMAL '-12.4' AS varchar(3))", "CAST(DECIMAL '-12.4' AS varchar(3))"); + assertSimplifies("CAST(DECIMAL '12.4' AS varchar(3)) = '12.4'", "CAST(DECIMAL '12.4' AS varchar(3)) = '12.4'"); } @Test @@ -206,16 +200,10 @@ public void testCastLongDecimalToBoundedVarchar() assertSimplifies("CAST(DECIMAL '100000000000000000.1' AS varchar(20))", "'100000000000000000.1'"); assertSimplifies("CAST(DECIMAL '-100000000000000000.1' AS varchar(50))", "CAST('-100000000000000000.1' AS varchar(50))"); - // the varchar type length is not enough to contain the number's representation: - // the cast operator returns a value that is too long for the expected type ('100000000000000000.1' for varchar(3)) - // the value is then wrapped in another cast by the LiteralEncoder (CAST('100000000000000000.1' AS varchar(3))), - // so eventually we get a truncated string '100' - assertSimplifies("CAST(DECIMAL '100000000000000000.1' AS varchar(3))", "CAST('100000000000000000.1' AS varchar(3))"); - assertSimplifies("CAST(DECIMAL '-100000000000000000.1' AS varchar(3))", "CAST('-100000000000000000.1' AS varchar(3))"); - - // the cast operator returns a value that is too long for the expected type ('100000000000000000.1' for varchar(3)) - // the value is nested in a comparison expression, so it is not truncated by the LiteralEncoder - assertSimplifies("CAST(DECIMAL '100000000000000000.1' AS varchar(3)) = '100000000000000000.1'", "true"); + // cast from long decimal to varchar fails, so the expression is not modified + assertSimplifies("CAST(DECIMAL '100000000000000000.1' AS varchar(3))", "CAST(DECIMAL '100000000000000000.1' AS varchar(3))"); + assertSimplifies("CAST(DECIMAL '-100000000000000000.1' AS varchar(3))", "CAST(DECIMAL '-100000000000000000.1' AS varchar(3))"); + assertSimplifies("CAST(DECIMAL '100000000000000000.1' AS varchar(3)) = '100000000000000000.1'", "CAST(DECIMAL '100000000000000000.1' AS varchar(3)) = '100000000000000000.1'"); } private static void assertSimplifies(String expression, String expected) diff --git a/core/trino-main/src/test/java/io/trino/type/TestDecimalCasts.java b/core/trino-main/src/test/java/io/trino/type/TestDecimalCasts.java index 54733e9e8387..6f5ab0083c75 100644 --- a/core/trino-main/src/test/java/io/trino/type/TestDecimalCasts.java +++ b/core/trino-main/src/test/java/io/trino/type/TestDecimalCasts.java @@ -472,10 +472,10 @@ public void testDecimalToVarcharCasts() assertFunction("cast(DECIMAL '12.4' as varchar(4))", createVarcharType(4), "12.4"); assertFunction("cast(DECIMAL '12.4' as varchar(50))", createVarcharType(50), "12.4"); - assertFunctionThrowsIncorrectly("cast(DECIMAL '12.4' as varchar(3))", IllegalArgumentException.class, "Character count exceeds length limit 3.*"); + assertInvalidCast("cast(DECIMAL '12.4' as varchar(3))", "Value 12.4 cannot be represented as varchar(3)"); assertFunction("cast(DECIMAL '100000000000000000.1' as varchar(20))", createVarcharType(20), "100000000000000000.1"); assertFunction("cast(DECIMAL '100000000000000000.1' as varchar(50))", createVarcharType(50), "100000000000000000.1"); - assertFunctionThrowsIncorrectly("cast(DECIMAL '100000000000000000.1' as varchar(19))", IllegalArgumentException.class, "Character count exceeds length limit 19.*"); + assertInvalidCast("cast(DECIMAL '100000000000000000.1' as varchar(19))", "Value 100000000000000000.1 cannot be represented as varchar(19)"); } }