Skip to content

Commit

Permalink
Fix cast from decimal to varchar
Browse files Browse the repository at this point in the history
Fail if the resulting string does not fit in the
bounded length of the varchar type.
Applies both to short and long decimal types.
  • Loading branch information
kasiafi committed Dec 6, 2021
1 parent 8023e3d commit a1e6733
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 37 deletions.
47 changes: 42 additions & 5 deletions core/trino-main/src/main/java/io/trino/type/DecimalCasts.java
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import io.trino.spi.type.Decimals;
import io.trino.spi.type.StandardTypes;
import io.trino.spi.type.TypeSignature;
import io.trino.spi.type.VarcharType;
import io.trino.util.JsonCastException;

import java.io.IOException;
Expand Down Expand Up @@ -62,6 +63,7 @@
import static io.trino.spi.type.UnscaledDecimal128Arithmetic.rescale;
import static io.trino.spi.type.UnscaledDecimal128Arithmetic.unscaledDecimal;
import static io.trino.spi.type.UnscaledDecimal128Arithmetic.unscaledDecimalToUnscaledLong;
import static io.trino.spi.type.VarcharType.UNBOUNDED_LENGTH;
import static io.trino.type.JsonType.JSON;
import static io.trino.util.Failures.checkCondition;
import static io.trino.util.JsonUtil.createJsonGenerator;
Expand Down Expand Up @@ -91,7 +93,6 @@ public final class DecimalCasts
public static final SqlScalarFunction DOUBLE_TO_DECIMAL_CAST = castFunctionToDecimalFrom(DOUBLE.getTypeSignature(), "doubleToShortDecimal", "doubleToLongDecimal");
public static final SqlScalarFunction DECIMAL_TO_REAL_CAST = castFunctionFromDecimalTo(REAL.getTypeSignature(), "shortDecimalToReal", "longDecimalToReal");
public static final SqlScalarFunction REAL_TO_DECIMAL_CAST = castFunctionToDecimalFrom(REAL.getTypeSignature(), "realToShortDecimal", "realToLongDecimal");
public static final SqlScalarFunction DECIMAL_TO_VARCHAR_CAST = castFunctionFromDecimalTo(new TypeSignature("varchar", typeVariable("x")), "shortDecimalToVarchar", "longDecimalToVarchar");
public static final SqlScalarFunction VARCHAR_TO_DECIMAL_CAST = castFunctionToDecimalFrom(new TypeSignature("varchar", typeVariable("x")), "varcharToShortDecimal", "varcharToLongDecimal");
public static final SqlScalarFunction DECIMAL_TO_JSON_CAST = castFunctionFromDecimalTo(JSON.getTypeSignature(), "shortDecimalToJson", "longDecimalToJson");
public static final SqlScalarFunction JSON_TO_DECIMAL_CAST = castFunctionToDecimalFromBuilder(JSON.getTypeSignature(), true, "jsonToShortDecimal", "jsonToLongDecimal");
Expand Down Expand Up @@ -157,6 +158,30 @@ private static SqlScalarFunction castFunctionToDecimalFromBuilder(TypeSignature
}))).build();
}

public static final SqlScalarFunction DECIMAL_TO_VARCHAR_CAST = new PolymorphicScalarFunctionBuilder(DecimalCasts.class)
.signature(Signature.builder()
.operatorType(CAST)
.argumentTypes(new TypeSignature("decimal", typeVariable("precision"), typeVariable("scale")))
.returnType(new TypeSignature("varchar", typeVariable("x")))
.build())
.deterministic(true)
.choice(choice -> choice
.implementation(methodsGroup -> methodsGroup
.methods("shortDecimalToVarchar", "longDecimalToVarchar")
.withExtraParameters((context) -> {
long scale = context.getLiteral("scale");
VarcharType resultType = (VarcharType) context.getReturnType();
long length;
if (resultType.isUnbounded()) {
length = UNBOUNDED_LENGTH;
}
else {
length = resultType.getBoundedLength();
}
return ImmutableList.of(scale, length);
})))
.build();

private DecimalCasts() {}

@UsedByGeneratedCode
Expand Down Expand Up @@ -457,15 +482,27 @@ public static Slice realToLongDecimal(long value, long precision, long scale, Bi
}

@UsedByGeneratedCode
public static Slice shortDecimalToVarchar(long decimal, long precision, long scale, long tenToScale)
public static Slice shortDecimalToVarchar(long decimal, long scale, long varcharLength)
{
return utf8Slice(Decimals.toString(decimal, DecimalConversions.intScale(scale)));
String stringValue = Decimals.toString(decimal, DecimalConversions.intScale(scale));
// String is all-ASCII, so String.length() here returns actual code points count
if (stringValue.length() <= varcharLength) {
return utf8Slice(stringValue);
}

throw new TrinoException(INVALID_CAST_ARGUMENT, format("Value %s cannot be represented as varchar(%s)", stringValue, varcharLength));
}

@UsedByGeneratedCode
public static Slice longDecimalToVarchar(Slice decimal, long precision, long scale, BigInteger tenToScale)
public static Slice longDecimalToVarchar(Slice decimal, long scale, long varcharLength)
{
return utf8Slice(Decimals.toString(decimal, DecimalConversions.intScale(scale)));
String stringValue = Decimals.toString(decimal, DecimalConversions.intScale(scale));
// String is all-ASCII, so String.length() here returns actual code points count
if (stringValue.length() <= varcharLength) {
return utf8Slice(stringValue);
}

throw new TrinoException(INVALID_CAST_ARGUMENT, format("Value %s cannot be represented as varchar(%s)", stringValue, varcharLength));
}

@UsedByGeneratedCode
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -687,23 +687,39 @@ public void testCastDecimalToBoundedVarchar()
assertEvaluatedEquals("CAST(DECIMAL '12.4' AS varchar(4))", "'12.4'");
assertEvaluatedEquals("CAST(DECIMAL '12.4' AS varchar(50))", "'12.4'");

// short decimal: incorrect behavior: the result value does not fit in the type
assertEvaluatedEquals("CAST(DECIMAL '12.4' AS varchar(3))", "'12.4'");
assertEvaluatedEquals("CAST(DECIMAL '-12.4' AS varchar(3))", "'-12.4'");
assertTrinoExceptionThrownBy(() -> evaluate("CAST(DECIMAL '12.4' AS varchar(3))"))
.hasErrorCode(INVALID_CAST_ARGUMENT)
.hasMessage("Value 12.4 cannot be represented as varchar(3)");
assertTrinoExceptionThrownBy(() -> evaluate("CAST(DECIMAL '-12.4' AS varchar(3))"))
.hasErrorCode(INVALID_CAST_ARGUMENT)
.hasMessage("Value -12.4 cannot be represented as varchar(3)");

// the trailing 0 does not fit in the type
assertEvaluatedEquals("CAST(DECIMAL '12.40' AS varchar(4))", "'12.40'");
assertEvaluatedEquals("CAST(DECIMAL '-12.40' AS varchar(5))", "'-12.40'");
assertTrinoExceptionThrownBy(() -> evaluate("CAST(DECIMAL '12.40' AS varchar(4))"))
.hasErrorCode(INVALID_CAST_ARGUMENT)
.hasMessage("Value 12.40 cannot be represented as varchar(4)");
assertTrinoExceptionThrownBy(() -> evaluate("CAST(DECIMAL '-12.40' AS varchar(5))"))
.hasErrorCode(INVALID_CAST_ARGUMENT)
.hasMessage("Value -12.40 cannot be represented as varchar(5)");

// long decimal
assertEvaluatedEquals("CAST(DECIMAL '100000000000000000.1' AS varchar(20))", "'100000000000000000.1'");
assertEvaluatedEquals("CAST(DECIMAL '100000000000000000.1' AS varchar(50))", "'100000000000000000.1'");

// long decimal: incorrect behavior: the result value does not fit in the type
assertEvaluatedEquals("CAST(DECIMAL '100000000000000000.1' AS varchar(3))", "'100000000000000000.1'");
assertEvaluatedEquals("CAST(DECIMAL '-100000000000000000.1' AS varchar(3))", "'-100000000000000000.1'");
assertTrinoExceptionThrownBy(() -> evaluate("CAST(DECIMAL '100000000000000000.1' AS varchar(3))"))
.hasErrorCode(INVALID_CAST_ARGUMENT)
.hasMessage("Value 100000000000000000.1 cannot be represented as varchar(3)");
assertTrinoExceptionThrownBy(() -> evaluate("CAST(DECIMAL '-100000000000000000.1' AS varchar(3))"))
.hasErrorCode(INVALID_CAST_ARGUMENT)
.hasMessage("Value -100000000000000000.1 cannot be represented as varchar(3)");

// the trailing 0 does not fit in the type
assertEvaluatedEquals("CAST(DECIMAL '100000000000000000.10' AS varchar(20))", "'100000000000000000.10'");
assertEvaluatedEquals("CAST(DECIMAL '-100000000000000000.10' AS varchar(21))", "'-100000000000000000.10'");
assertTrinoExceptionThrownBy(() -> evaluate("CAST(DECIMAL '100000000000000000.10' AS varchar(20))"))
.hasErrorCode(INVALID_CAST_ARGUMENT)
.hasMessage("Value 100000000000000000.10 cannot be represented as varchar(20)");
assertTrinoExceptionThrownBy(() -> evaluate("CAST(DECIMAL '-100000000000000000.10' AS varchar(21))"))
.hasErrorCode(INVALID_CAST_ARGUMENT)
.hasMessage("Value -100000000000000000.10 cannot be represented as varchar(21)");
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -187,16 +187,10 @@ public void testCastShortDecimalToBoundedVarchar()
assertSimplifies("CAST(DECIMAL '12.4' AS varchar(4))", "'12.4'");
assertSimplifies("CAST(DECIMAL '-12.4' AS varchar(50))", "CAST('-12.4' AS varchar(50))");

// the varchar type length is not enough to contain the number's representation:
// the cast operator returns a value that is too long for the expected type ('12.4' for varchar(3))
// the value is then wrapped in another cast by the LiteralEncoder (CAST('12.4' AS varchar(3))),
// so eventually we get a truncated string '12.'
assertSimplifies("CAST(DECIMAL '12.4' AS varchar(3))", "CAST('12.4' AS varchar(3))");
assertSimplifies("CAST(DECIMAL '-12.4' AS varchar(3))", "CAST('-12.4' AS varchar(3))");

// the cast operator returns a value that is too long for the expected type ('12.4' for varchar(3))
// the value is nested in a comparison expression, so it is not truncated by the LiteralEncoder
assertSimplifies("CAST(DECIMAL '12.4' AS varchar(3)) = '12.4'", "true");
// cast from short decimal to varchar fails, so the expression is not modified
assertSimplifies("CAST(DECIMAL '12.4' AS varchar(3))", "CAST(DECIMAL '12.4' AS varchar(3))");
assertSimplifies("CAST(DECIMAL '-12.4' AS varchar(3))", "CAST(DECIMAL '-12.4' AS varchar(3))");
assertSimplifies("CAST(DECIMAL '12.4' AS varchar(3)) = '12.4'", "CAST(DECIMAL '12.4' AS varchar(3)) = '12.4'");
}

@Test
Expand All @@ -206,16 +200,10 @@ public void testCastLongDecimalToBoundedVarchar()
assertSimplifies("CAST(DECIMAL '100000000000000000.1' AS varchar(20))", "'100000000000000000.1'");
assertSimplifies("CAST(DECIMAL '-100000000000000000.1' AS varchar(50))", "CAST('-100000000000000000.1' AS varchar(50))");

// the varchar type length is not enough to contain the number's representation:
// the cast operator returns a value that is too long for the expected type ('100000000000000000.1' for varchar(3))
// the value is then wrapped in another cast by the LiteralEncoder (CAST('100000000000000000.1' AS varchar(3))),
// so eventually we get a truncated string '100'
assertSimplifies("CAST(DECIMAL '100000000000000000.1' AS varchar(3))", "CAST('100000000000000000.1' AS varchar(3))");
assertSimplifies("CAST(DECIMAL '-100000000000000000.1' AS varchar(3))", "CAST('-100000000000000000.1' AS varchar(3))");

// the cast operator returns a value that is too long for the expected type ('100000000000000000.1' for varchar(3))
// the value is nested in a comparison expression, so it is not truncated by the LiteralEncoder
assertSimplifies("CAST(DECIMAL '100000000000000000.1' AS varchar(3)) = '100000000000000000.1'", "true");
// cast from long decimal to varchar fails, so the expression is not modified
assertSimplifies("CAST(DECIMAL '100000000000000000.1' AS varchar(3))", "CAST(DECIMAL '100000000000000000.1' AS varchar(3))");
assertSimplifies("CAST(DECIMAL '-100000000000000000.1' AS varchar(3))", "CAST(DECIMAL '-100000000000000000.1' AS varchar(3))");
assertSimplifies("CAST(DECIMAL '100000000000000000.1' AS varchar(3)) = '100000000000000000.1'", "CAST(DECIMAL '100000000000000000.1' AS varchar(3)) = '100000000000000000.1'");
}

private static void assertSimplifies(String expression, String expected)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -472,10 +472,10 @@ public void testDecimalToVarcharCasts()

assertFunction("cast(DECIMAL '12.4' as varchar(4))", createVarcharType(4), "12.4");
assertFunction("cast(DECIMAL '12.4' as varchar(50))", createVarcharType(50), "12.4");
assertFunctionThrowsIncorrectly("cast(DECIMAL '12.4' as varchar(3))", IllegalArgumentException.class, "Character count exceeds length limit 3.*");
assertInvalidCast("cast(DECIMAL '12.4' as varchar(3))", "Value 12.4 cannot be represented as varchar(3)");

assertFunction("cast(DECIMAL '100000000000000000.1' as varchar(20))", createVarcharType(20), "100000000000000000.1");
assertFunction("cast(DECIMAL '100000000000000000.1' as varchar(50))", createVarcharType(50), "100000000000000000.1");
assertFunctionThrowsIncorrectly("cast(DECIMAL '100000000000000000.1' as varchar(19))", IllegalArgumentException.class, "Character count exceeds length limit 19.*");
assertInvalidCast("cast(DECIMAL '100000000000000000.1' as varchar(19))", "Value 100000000000000000.1 cannot be represented as varchar(19)");
}
}

0 comments on commit a1e6733

Please sign in to comment.