Skip to content

Commit

Permalink
Test and fix cast from bigint to varchar
Browse files Browse the repository at this point in the history
Cherry-pick of trinodb/trino#10090

Co-authored-by: kasiafi <[email protected]>
  • Loading branch information
v-jizhang and kasiafi committed Feb 3, 2022
1 parent 7c3db27 commit d4ca4db
Show file tree
Hide file tree
Showing 6 changed files with 124 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
import static com.facebook.presto.common.function.OperatorType.XX_HASH_64;
import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.spi.StandardErrorCode.DIVISION_BY_ZERO;
import static com.facebook.presto.spi.StandardErrorCode.INVALID_CAST_ARGUMENT;
import static com.facebook.presto.spi.StandardErrorCode.NUMERIC_VALUE_OUT_OF_RANGE;
import static io.airlift.slice.Slices.utf8Slice;
import static java.lang.Float.floatToRawIntBits;
Expand Down Expand Up @@ -270,10 +271,16 @@ public static long castToReal(@SqlType(StandardTypes.BIGINT) long value)
@ScalarOperator(CAST)
@LiteralParameters("x")
@SqlType("varchar(x)")
public static Slice castToVarchar(@SqlType(StandardTypes.BIGINT) long value)
public static Slice castToVarchar(@LiteralParameter("x") long x, @SqlType(StandardTypes.BIGINT) long value)
{
// todo optimize me
return utf8Slice(String.valueOf(value));
String stringValue = String.valueOf(value);
// String is all-ASCII, so String.length() here returns actual code points count
if (stringValue.length() <= x) {
return utf8Slice(stringValue);
}

throw new PrestoException(INVALID_CAST_ARGUMENT, format("Value %s cannot be represented as varchar(%s)", value, x));
}

@ScalarOperator(HASH_CODE)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import com.facebook.presto.sql.analyzer.SemanticErrorCode;
import com.google.common.collect.ImmutableList;
import io.airlift.slice.Slice;
import org.intellij.lang.annotations.Language;
import org.testng.annotations.AfterClass;
import org.testng.annotations.BeforeClass;

Expand Down Expand Up @@ -152,6 +153,11 @@ protected void assertInvalidFunction(String projection, ErrorCodeSupplier expect
functionAssertions.assertInvalidFunction(projection, expectedErrorCode);
}

protected void assertFunctionThrowsIncorrectly(@Language("SQL") String projection, Class<? extends Throwable> throwableClass, @Language("RegExp") String message)
{
functionAssertions.assertFunctionThrowsIncorrectly(projection, throwableClass, message);
}

protected void assertNumericOverflow(String projection, String message)
{
functionAssertions.assertNumericOverflow(projection, message);
Expand All @@ -162,7 +168,7 @@ protected void assertInvalidCast(String projection)
functionAssertions.assertInvalidCast(projection);
}

protected void assertInvalidCast(String projection, String message)
protected void assertInvalidCast(@Language("SQL") String projection, String message)
{
functionAssertions.assertInvalidCast(projection, message);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@
import io.airlift.slice.Slice;
import io.airlift.slice.Slices;
import io.airlift.units.DataSize;
import org.intellij.lang.annotations.Language;
import org.joda.time.DateTime;
import org.joda.time.DateTimeZone;
import org.openjdk.jol.info.ClassLayout;
Expand Down Expand Up @@ -148,6 +149,7 @@
import static java.util.Objects.requireNonNull;
import static java.util.concurrent.Executors.newCachedThreadPool;
import static java.util.concurrent.Executors.newScheduledThreadPool;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertNotNull;
import static org.testng.Assert.assertNull;
Expand Down Expand Up @@ -404,6 +406,13 @@ public void assertInvalidFunction(String projection, ErrorCodeSupplier expectedE
}
}

public void assertFunctionThrowsIncorrectly(@Language("SQL") String projection, Class<? extends Throwable> throwableClass, @Language("RegExp") String message)
{
assertThatThrownBy(() -> evaluateInvalid(projection))
.isInstanceOf(throwableClass)
.hasMessageMatching(message);
}

public void assertNumericOverflow(String projection, String message)
{
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@
import static com.facebook.presto.common.type.VarcharType.VARCHAR;
import static com.facebook.presto.common.type.VarcharType.createVarcharType;
import static com.facebook.presto.operator.scalar.ApplyFunction.APPLY_FUNCTION;
import static com.facebook.presto.spi.StandardErrorCode.INVALID_CAST_ARGUMENT;
import static com.facebook.presto.spi.relation.ExpressionOptimizer.Level;
import static com.facebook.presto.spi.relation.ExpressionOptimizer.Level.OPTIMIZED;
import static com.facebook.presto.spi.relation.ExpressionOptimizer.Level.SERIALIZABLE;
Expand All @@ -102,6 +103,7 @@
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertThrows;
import static org.testng.Assert.assertTrue;
import static org.testng.Assert.fail;

public class TestExpressionInterpreter
{
Expand Down Expand Up @@ -584,6 +586,42 @@ public void testCastToString()
// TODO enabled when DECIMAL is default for literal: assertOptimizedEquals("cast(12345678901234567890.123 as VARCHAR)", "'12345678901234567890.123'");
}

@Test
public void testCastBigintToBoundedVarchar()
{
assertEvaluatedEquals("CAST(12300000000 AS varchar(11))", "'12300000000'");
assertEvaluatedEquals("CAST(12300000000 AS varchar(50))", "'12300000000'");

try {
evaluate("CAST(12300000000 AS varchar(3))", true);
fail("Expected to throw an INVALID_CAST_ARGUMENT exception");
}
catch (PrestoException e) {
try {
assertEquals(e.getErrorCode(), INVALID_CAST_ARGUMENT.toErrorCode());
assertEquals(e.getMessage(), "Value 12300000000 cannot be represented as varchar(3)");
}
catch (Throwable failure) {
failure.addSuppressed(e);
throw failure;
}
}

try {
evaluate("CAST(-12300000000 AS varchar(3))", true);
}
catch (PrestoException e) {
try {
assertEquals(e.getErrorCode(), INVALID_CAST_ARGUMENT.toErrorCode());
assertEquals(e.getMessage(), "Value -12300000000 cannot be represented as varchar(3)");
}
catch (Throwable failure) {
failure.addSuppressed(e);
throw failure;
}
}
}

@Test
public void testCastToBoolean()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import com.facebook.presto.common.type.Type;
import com.facebook.presto.metadata.MetadataManager;
import com.facebook.presto.spi.PrestoException;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.TestingRowExpressionTranslator;
Expand All @@ -23,6 +24,7 @@
import com.facebook.presto.sql.planner.PlanVariableAllocator;
import com.facebook.presto.sql.planner.TypeProvider;
import com.facebook.presto.sql.planner.VariablesExtractor;
import com.facebook.presto.sql.tree.Cast;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.ExpressionRewriter;
import com.facebook.presto.sql.tree.ExpressionTreeRewriter;
Expand All @@ -42,6 +44,7 @@
import static com.facebook.presto.SessionTestUtils.TEST_SESSION;
import static com.facebook.presto.common.type.BooleanType.BOOLEAN;
import static com.facebook.presto.metadata.MetadataManager.createTestMetadataManager;
import static com.facebook.presto.spi.StandardErrorCode.INVALID_CAST_ARGUMENT;
import static com.facebook.presto.sql.ExpressionUtils.binaryExpression;
import static com.facebook.presto.sql.ExpressionUtils.extractPredicates;
import static com.facebook.presto.sql.ExpressionUtils.rewriteIdentifiersToSymbolReferences;
Expand All @@ -51,6 +54,7 @@
import static java.util.stream.Collectors.toList;
import static java.util.stream.Collectors.toMap;
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.fail;

public class TestSimplifyExpressions
{
Expand Down Expand Up @@ -132,6 +136,45 @@ public void testExtractCommonPredicates()
" OR (A51 AND A52) OR (A53 AND A54) OR (A55 AND A56) OR (A57 AND A58) OR (A59 AND A60)");
}

@Test
public void testCastBigintToBoundedVarchar()
{
// the varchar type length is enough to contain the number's representation
assertSimplifies("CAST(12300000000 AS varchar(11))", "'12300000000'");
// The last argument "'-12300000000'" is varchar(12). Need varchar(50) to the following test pass.
//assertSimplifies("CAST(-12300000000 AS varchar(50))", "CAST('-12300000000' AS varchar(50))", "'-12300000000'");

// cast from bigint to varchar fails, so the expression is not modified
try {
assertSimplifies("CAST(12300000000 AS varchar(3))", "CAST(12300000000 AS varchar(3))");
fail("Expected to throw an PrestoException exception");
}
catch (PrestoException e) {
try {
assertEquals(e.getErrorCode(), INVALID_CAST_ARGUMENT.toErrorCode());
assertEquals(e.getMessage(), "Value 12300000000 cannot be represented as varchar(3)");
}
catch (Throwable failure) {
failure.addSuppressed(e);
throw failure;
}
}

try {
assertSimplifies("CAST(-12300000000 AS varchar(3))", "CAST(-12300000000 AS varchar(3))");
}
catch (PrestoException e) {
try {
assertEquals(e.getErrorCode(), INVALID_CAST_ARGUMENT.toErrorCode());
assertEquals(e.getMessage(), "Value -12300000000 cannot be represented as varchar(3)");
}
catch (Throwable failure) {
failure.addSuppressed(e);
throw failure;
}
}
}

private static void assertSimplifies(String expression, String expected)
{
assertSimplifies(expression, expected, null);
Expand Down Expand Up @@ -177,5 +220,13 @@ public Expression rewriteLogicalBinaryExpression(LogicalBinaryExpression node, V
.collect(toList());
return binaryExpression(node.getOperator(), predicates);
}

@Override
public Expression rewriteCast(Cast node, Void context, ExpressionTreeRewriter<Void> treeRewriter)
{
// the `expected` Cast expression comes out of the AstBuilder with the `typeOnly` flag set to false.
// always set the `typeOnly` flag to false so that it does not break the comparison.
return new Cast(node.getExpression(), node.getType(), node.isSafe(), false);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import static com.facebook.presto.common.type.DoubleType.DOUBLE;
import static com.facebook.presto.common.type.RealType.REAL;
import static com.facebook.presto.common.type.VarcharType.VARCHAR;
import static com.facebook.presto.common.type.VarcharType.createVarcharType;
import static java.lang.String.format;

public class TestBigintOperators
Expand Down Expand Up @@ -188,31 +189,34 @@ public void testCastToBigint()
@Test
public void testCastToVarchar()
{
assertFunction("cast(37 as varchar)", VARCHAR, "37");
assertFunction("cast(BIGINT '37' as varchar)", VARCHAR, "37");
assertFunction("cast(100000000017 as varchar)", VARCHAR, "100000000017");
assertFunction("cast(100000000017 as varchar(13))", createVarcharType(13), "100000000017");
assertFunction("cast(100000000017 as varchar(50))", createVarcharType(50), "100000000017");
assertInvalidCast("cast(100000000017 as varchar(2))", "Value 100000000017 cannot be represented as varchar(2)");
}

@Test
public void testCastToDouble()
{
assertFunction("cast(37 as double)", DOUBLE, 37.0);
assertFunction("cast(BIGINT '37' as double)", DOUBLE, 37.0);
assertFunction("cast(100000000017 as double)", DOUBLE, 100000000017.0);
}

@Test
public void testCastToFloat()
{
assertFunction("cast(37 as real)", REAL, 37.0f);
assertFunction("cast(BIGINT '37' as real)", REAL, 37.0f);
assertFunction("cast(-100000000017 as real)", REAL, -100000000017.0f);
assertFunction("cast(0 as real)", REAL, 0.0f);
assertFunction("cast(BIGINT '0' as real)", REAL, 0.0f);
}

@Test
public void testCastToBoolean()
{
assertFunction("cast(37 as boolean)", BOOLEAN, true);
assertFunction("cast(BIGINT '37' as boolean)", BOOLEAN, true);
assertFunction("cast(100000000017 as boolean)", BOOLEAN, true);
assertFunction("cast(0 as boolean)", BOOLEAN, false);
assertFunction("cast(BIGINT '0' as boolean)", BOOLEAN, false);
}

@Test
Expand Down

0 comments on commit d4ca4db

Please sign in to comment.