From d4ca4dbcb92d4d40c23e0af6cec2d843bfbdfde7 Mon Sep 17 00:00:00 2001 From: v-jizhang Date: Mon, 3 Jan 2022 08:45:09 -0800 Subject: [PATCH] Test and fix cast from bigint to varchar Cherry-pick of https://github.com/trinodb/trino/pull/10090 Co-authored-by: kasiafi <30203062+kasiafi@users.noreply.github.com> --- .../facebook/presto/type/BigintOperators.java | 11 +++- .../scalar/AbstractTestFunctions.java | 8 ++- .../operator/scalar/FunctionAssertions.java | 9 ++++ .../presto/sql/TestExpressionInterpreter.java | 38 ++++++++++++++ .../rule/TestSimplifyExpressions.java | 51 +++++++++++++++++++ .../presto/type/TestBigintOperators.java | 16 +++--- 6 files changed, 124 insertions(+), 9 deletions(-) diff --git a/presto-main/src/main/java/com/facebook/presto/type/BigintOperators.java b/presto-main/src/main/java/com/facebook/presto/type/BigintOperators.java index c1a156d1d82e1..92dc55dd997db 100644 --- a/presto-main/src/main/java/com/facebook/presto/type/BigintOperators.java +++ b/presto-main/src/main/java/com/facebook/presto/type/BigintOperators.java @@ -51,6 +51,7 @@ import static com.facebook.presto.common.function.OperatorType.XX_HASH_64; import static com.facebook.presto.common.type.BigintType.BIGINT; import static com.facebook.presto.spi.StandardErrorCode.DIVISION_BY_ZERO; +import static com.facebook.presto.spi.StandardErrorCode.INVALID_CAST_ARGUMENT; import static com.facebook.presto.spi.StandardErrorCode.NUMERIC_VALUE_OUT_OF_RANGE; import static io.airlift.slice.Slices.utf8Slice; import static java.lang.Float.floatToRawIntBits; @@ -270,10 +271,16 @@ public static long castToReal(@SqlType(StandardTypes.BIGINT) long value) @ScalarOperator(CAST) @LiteralParameters("x") @SqlType("varchar(x)") - public static Slice castToVarchar(@SqlType(StandardTypes.BIGINT) long value) + public static Slice castToVarchar(@LiteralParameter("x") long x, @SqlType(StandardTypes.BIGINT) long value) { // todo optimize me - return utf8Slice(String.valueOf(value)); + String stringValue = String.valueOf(value); + // String is all-ASCII, so String.length() here returns actual code points count + if (stringValue.length() <= x) { + return utf8Slice(stringValue); + } + + throw new PrestoException(INVALID_CAST_ARGUMENT, format("Value %s cannot be represented as varchar(%s)", value, x)); } @ScalarOperator(HASH_CODE) diff --git a/presto-main/src/test/java/com/facebook/presto/operator/scalar/AbstractTestFunctions.java b/presto-main/src/test/java/com/facebook/presto/operator/scalar/AbstractTestFunctions.java index 36ceea199ff17..9497f7f08ecff 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/scalar/AbstractTestFunctions.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/scalar/AbstractTestFunctions.java @@ -31,6 +31,7 @@ import com.facebook.presto.sql.analyzer.SemanticErrorCode; import com.google.common.collect.ImmutableList; import io.airlift.slice.Slice; +import org.intellij.lang.annotations.Language; import org.testng.annotations.AfterClass; import org.testng.annotations.BeforeClass; @@ -152,6 +153,11 @@ protected void assertInvalidFunction(String projection, ErrorCodeSupplier expect functionAssertions.assertInvalidFunction(projection, expectedErrorCode); } + protected void assertFunctionThrowsIncorrectly(@Language("SQL") String projection, Class throwableClass, @Language("RegExp") String message) + { + functionAssertions.assertFunctionThrowsIncorrectly(projection, throwableClass, message); + } + protected void assertNumericOverflow(String projection, String message) { functionAssertions.assertNumericOverflow(projection, message); @@ -162,7 +168,7 @@ protected void assertInvalidCast(String projection) functionAssertions.assertInvalidCast(projection); } - protected void assertInvalidCast(String projection, String message) + protected void assertInvalidCast(@Language("SQL") String projection, String message) { functionAssertions.assertInvalidCast(projection, message); } diff --git a/presto-main/src/test/java/com/facebook/presto/operator/scalar/FunctionAssertions.java b/presto-main/src/test/java/com/facebook/presto/operator/scalar/FunctionAssertions.java index 772a330080478..2272061bd8dcf 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/scalar/FunctionAssertions.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/scalar/FunctionAssertions.java @@ -88,6 +88,7 @@ import io.airlift.slice.Slice; import io.airlift.slice.Slices; import io.airlift.units.DataSize; +import org.intellij.lang.annotations.Language; import org.joda.time.DateTime; import org.joda.time.DateTimeZone; import org.openjdk.jol.info.ClassLayout; @@ -148,6 +149,7 @@ import static java.util.Objects.requireNonNull; import static java.util.concurrent.Executors.newCachedThreadPool; import static java.util.concurrent.Executors.newScheduledThreadPool; +import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertNotNull; import static org.testng.Assert.assertNull; @@ -404,6 +406,13 @@ public void assertInvalidFunction(String projection, ErrorCodeSupplier expectedE } } + public void assertFunctionThrowsIncorrectly(@Language("SQL") String projection, Class throwableClass, @Language("RegExp") String message) + { + assertThatThrownBy(() -> evaluateInvalid(projection)) + .isInstanceOf(throwableClass) + .hasMessageMatching(message); + } + public void assertNumericOverflow(String projection, String message) { try { diff --git a/presto-main/src/test/java/com/facebook/presto/sql/TestExpressionInterpreter.java b/presto-main/src/test/java/com/facebook/presto/sql/TestExpressionInterpreter.java index 0e74b659df158..2b6ab2fda1ebc 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/TestExpressionInterpreter.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/TestExpressionInterpreter.java @@ -83,6 +83,7 @@ import static com.facebook.presto.common.type.VarcharType.VARCHAR; import static com.facebook.presto.common.type.VarcharType.createVarcharType; import static com.facebook.presto.operator.scalar.ApplyFunction.APPLY_FUNCTION; +import static com.facebook.presto.spi.StandardErrorCode.INVALID_CAST_ARGUMENT; import static com.facebook.presto.spi.relation.ExpressionOptimizer.Level; import static com.facebook.presto.spi.relation.ExpressionOptimizer.Level.OPTIMIZED; import static com.facebook.presto.spi.relation.ExpressionOptimizer.Level.SERIALIZABLE; @@ -102,6 +103,7 @@ import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertThrows; import static org.testng.Assert.assertTrue; +import static org.testng.Assert.fail; public class TestExpressionInterpreter { @@ -584,6 +586,42 @@ public void testCastToString() // TODO enabled when DECIMAL is default for literal: assertOptimizedEquals("cast(12345678901234567890.123 as VARCHAR)", "'12345678901234567890.123'"); } + @Test + public void testCastBigintToBoundedVarchar() + { + assertEvaluatedEquals("CAST(12300000000 AS varchar(11))", "'12300000000'"); + assertEvaluatedEquals("CAST(12300000000 AS varchar(50))", "'12300000000'"); + + try { + evaluate("CAST(12300000000 AS varchar(3))", true); + fail("Expected to throw an INVALID_CAST_ARGUMENT exception"); + } + catch (PrestoException e) { + try { + assertEquals(e.getErrorCode(), INVALID_CAST_ARGUMENT.toErrorCode()); + assertEquals(e.getMessage(), "Value 12300000000 cannot be represented as varchar(3)"); + } + catch (Throwable failure) { + failure.addSuppressed(e); + throw failure; + } + } + + try { + evaluate("CAST(-12300000000 AS varchar(3))", true); + } + catch (PrestoException e) { + try { + assertEquals(e.getErrorCode(), INVALID_CAST_ARGUMENT.toErrorCode()); + assertEquals(e.getMessage(), "Value -12300000000 cannot be represented as varchar(3)"); + } + catch (Throwable failure) { + failure.addSuppressed(e); + throw failure; + } + } + } + @Test public void testCastToBoolean() { diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestSimplifyExpressions.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestSimplifyExpressions.java index babcfa2f9df18..9d1453c577ba5 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestSimplifyExpressions.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestSimplifyExpressions.java @@ -15,6 +15,7 @@ import com.facebook.presto.common.type.Type; import com.facebook.presto.metadata.MetadataManager; +import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.TestingRowExpressionTranslator; @@ -23,6 +24,7 @@ import com.facebook.presto.sql.planner.PlanVariableAllocator; import com.facebook.presto.sql.planner.TypeProvider; import com.facebook.presto.sql.planner.VariablesExtractor; +import com.facebook.presto.sql.tree.Cast; import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.ExpressionRewriter; import com.facebook.presto.sql.tree.ExpressionTreeRewriter; @@ -42,6 +44,7 @@ import static com.facebook.presto.SessionTestUtils.TEST_SESSION; import static com.facebook.presto.common.type.BooleanType.BOOLEAN; import static com.facebook.presto.metadata.MetadataManager.createTestMetadataManager; +import static com.facebook.presto.spi.StandardErrorCode.INVALID_CAST_ARGUMENT; import static com.facebook.presto.sql.ExpressionUtils.binaryExpression; import static com.facebook.presto.sql.ExpressionUtils.extractPredicates; import static com.facebook.presto.sql.ExpressionUtils.rewriteIdentifiersToSymbolReferences; @@ -51,6 +54,7 @@ import static java.util.stream.Collectors.toList; import static java.util.stream.Collectors.toMap; import static org.testng.Assert.assertEquals; +import static org.testng.Assert.fail; public class TestSimplifyExpressions { @@ -132,6 +136,45 @@ public void testExtractCommonPredicates() " OR (A51 AND A52) OR (A53 AND A54) OR (A55 AND A56) OR (A57 AND A58) OR (A59 AND A60)"); } + @Test + public void testCastBigintToBoundedVarchar() + { + // the varchar type length is enough to contain the number's representation + assertSimplifies("CAST(12300000000 AS varchar(11))", "'12300000000'"); + // The last argument "'-12300000000'" is varchar(12). Need varchar(50) to the following test pass. + //assertSimplifies("CAST(-12300000000 AS varchar(50))", "CAST('-12300000000' AS varchar(50))", "'-12300000000'"); + + // cast from bigint to varchar fails, so the expression is not modified + try { + assertSimplifies("CAST(12300000000 AS varchar(3))", "CAST(12300000000 AS varchar(3))"); + fail("Expected to throw an PrestoException exception"); + } + catch (PrestoException e) { + try { + assertEquals(e.getErrorCode(), INVALID_CAST_ARGUMENT.toErrorCode()); + assertEquals(e.getMessage(), "Value 12300000000 cannot be represented as varchar(3)"); + } + catch (Throwable failure) { + failure.addSuppressed(e); + throw failure; + } + } + + try { + assertSimplifies("CAST(-12300000000 AS varchar(3))", "CAST(-12300000000 AS varchar(3))"); + } + catch (PrestoException e) { + try { + assertEquals(e.getErrorCode(), INVALID_CAST_ARGUMENT.toErrorCode()); + assertEquals(e.getMessage(), "Value -12300000000 cannot be represented as varchar(3)"); + } + catch (Throwable failure) { + failure.addSuppressed(e); + throw failure; + } + } + } + private static void assertSimplifies(String expression, String expected) { assertSimplifies(expression, expected, null); @@ -177,5 +220,13 @@ public Expression rewriteLogicalBinaryExpression(LogicalBinaryExpression node, V .collect(toList()); return binaryExpression(node.getOperator(), predicates); } + + @Override + public Expression rewriteCast(Cast node, Void context, ExpressionTreeRewriter treeRewriter) + { + // the `expected` Cast expression comes out of the AstBuilder with the `typeOnly` flag set to false. + // always set the `typeOnly` flag to false so that it does not break the comparison. + return new Cast(node.getExpression(), node.getType(), node.isSafe(), false); + } } } diff --git a/presto-main/src/test/java/com/facebook/presto/type/TestBigintOperators.java b/presto-main/src/test/java/com/facebook/presto/type/TestBigintOperators.java index bcb48459664d2..6d2005e47dfed 100644 --- a/presto-main/src/test/java/com/facebook/presto/type/TestBigintOperators.java +++ b/presto-main/src/test/java/com/facebook/presto/type/TestBigintOperators.java @@ -22,6 +22,7 @@ import static com.facebook.presto.common.type.DoubleType.DOUBLE; import static com.facebook.presto.common.type.RealType.REAL; import static com.facebook.presto.common.type.VarcharType.VARCHAR; +import static com.facebook.presto.common.type.VarcharType.createVarcharType; import static java.lang.String.format; public class TestBigintOperators @@ -188,31 +189,34 @@ public void testCastToBigint() @Test public void testCastToVarchar() { - assertFunction("cast(37 as varchar)", VARCHAR, "37"); + assertFunction("cast(BIGINT '37' as varchar)", VARCHAR, "37"); assertFunction("cast(100000000017 as varchar)", VARCHAR, "100000000017"); + assertFunction("cast(100000000017 as varchar(13))", createVarcharType(13), "100000000017"); + assertFunction("cast(100000000017 as varchar(50))", createVarcharType(50), "100000000017"); + assertInvalidCast("cast(100000000017 as varchar(2))", "Value 100000000017 cannot be represented as varchar(2)"); } @Test public void testCastToDouble() { - assertFunction("cast(37 as double)", DOUBLE, 37.0); + assertFunction("cast(BIGINT '37' as double)", DOUBLE, 37.0); assertFunction("cast(100000000017 as double)", DOUBLE, 100000000017.0); } @Test public void testCastToFloat() { - assertFunction("cast(37 as real)", REAL, 37.0f); + assertFunction("cast(BIGINT '37' as real)", REAL, 37.0f); assertFunction("cast(-100000000017 as real)", REAL, -100000000017.0f); - assertFunction("cast(0 as real)", REAL, 0.0f); + assertFunction("cast(BIGINT '0' as real)", REAL, 0.0f); } @Test public void testCastToBoolean() { - assertFunction("cast(37 as boolean)", BOOLEAN, true); + assertFunction("cast(BIGINT '37' as boolean)", BOOLEAN, true); assertFunction("cast(100000000017 as boolean)", BOOLEAN, true); - assertFunction("cast(0 as boolean)", BOOLEAN, false); + assertFunction("cast(BIGINT '0' as boolean)", BOOLEAN, false); } @Test