diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/ConnectorExpressionTranslator.java b/core/trino-main/src/main/java/io/trino/sql/planner/ConnectorExpressionTranslator.java index b24d6d56366d..a68c9e1dbe1f 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/ConnectorExpressionTranslator.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/ConnectorExpressionTranslator.java @@ -32,7 +32,6 @@ import io.trino.spi.type.RowType; import io.trino.spi.type.Type; import io.trino.spi.type.TypeSignature; -import io.trino.spi.type.VarcharType; import io.trino.sql.DynamicFilters; import io.trino.sql.PlannerContext; import io.trino.sql.analyzer.TypeSignatureProvider; @@ -79,6 +78,7 @@ import static io.trino.SystemSessionProperties.isComplexExpressionPushdown; import static io.trino.spi.expression.StandardFunctions.ADD_FUNCTION_NAME; import static io.trino.spi.expression.StandardFunctions.AND_FUNCTION_NAME; +import static io.trino.spi.expression.StandardFunctions.CAST_FUNCTION_NAME; import static io.trino.spi.expression.StandardFunctions.DIVIDE_FUNCTION_NAME; import static io.trino.spi.expression.StandardFunctions.EQUAL_OPERATOR_FUNCTION_NAME; import static io.trino.spi.expression.StandardFunctions.GREATER_THAN_OPERATOR_FUNCTION_NAME; @@ -97,7 +97,10 @@ import static io.trino.spi.expression.StandardFunctions.OR_FUNCTION_NAME; import static io.trino.spi.expression.StandardFunctions.SUBTRACT_FUNCTION_NAME; import static io.trino.spi.type.BooleanType.BOOLEAN; +import static io.trino.spi.type.VarcharType.createVarcharType; import static io.trino.sql.ExpressionUtils.isEffectivelyLiteral; +import static io.trino.sql.analyzer.TypeSignatureTranslator.toSqlType; +import static io.trino.sql.analyzer.TypeSignatureTranslator.toTypeSignature; import static io.trino.sql.planner.ExpressionInterpreter.evaluateConstantExpression; import static java.lang.Math.toIntExact; import static java.util.Objects.requireNonNull; @@ -225,9 +228,13 @@ protected Optional translateCall(Call call) if (IS_NULL_FUNCTION_NAME.equals(call.getFunctionName()) && call.getArguments().size() == 1) { return translateIsNull(call.getArguments().get(0)); } + if (NULLIF_FUNCTION_NAME.equals(call.getFunctionName()) && call.getArguments().size() == 2) { return translateNullIf(call.getArguments().get(0), call.getArguments().get(1)); } + if (CAST_FUNCTION_NAME.equals(call.getFunctionName()) && call.getArguments().size() == 1) { + return translateCast(call.getType(), call.getArguments().get(0)); + } // comparisons if (call.getArguments().size() == 2) { @@ -303,6 +310,16 @@ private Optional translateNot(ConnectorExpression argument) if (argument.getType().equals(BOOLEAN) && translatedArgument.isPresent()) { return Optional.of(new NotExpression(translatedArgument.get())); } + return Optional.empty(); + } + + private Optional translateCast(Type type, ConnectorExpression expression) + { + Optional translatedExpression = translate(expression); + + if (translatedExpression.isPresent()) { + return Optional.of(new Cast(translatedExpression.get(), toSqlType(type))); + } return Optional.empty(); } @@ -573,6 +590,22 @@ protected Optional visitCast(Cast node, Void context) if (isEffectivelyLiteral(plannerContext, session, node)) { return Optional.of(constantFor(node)); } + + if (node.isSafe()) { + // try_cast would need to be modeled separately + return Optional.empty(); + } + + if (!isComplexExpressionPushdown(session)) { + return Optional.empty(); + } + + Optional translatedExpression = process(node.getExpression()); + if (translatedExpression.isPresent()) { + Type type = plannerContext.getTypeManager().getType(toTypeSignature(node.getType())); + return Optional.of(new Call(type, CAST_FUNCTION_NAME, List.of(translatedExpression.get()))); + } + return Optional.empty(); } @@ -598,11 +631,11 @@ protected Optional visitFunctionCall(FunctionCall node, Voi Object value = evaluateConstant(node); if (value instanceof JoniRegexp) { Slice pattern = ((JoniRegexp) value).pattern(); - return Optional.of(new Constant(pattern, VarcharType.createVarcharType(countCodePoints(pattern)))); + return Optional.of(new Constant(pattern, createVarcharType(countCodePoints(pattern)))); } if (value instanceof Re2JRegexp) { Slice pattern = Slices.utf8Slice(((Re2JRegexp) value).pattern()); - return Optional.of(new Constant(pattern, VarcharType.createVarcharType(countCodePoints(pattern)))); + return Optional.of(new Constant(pattern, createVarcharType(countCodePoints(pattern)))); } return Optional.of(new Constant(value, types.get(NodeRef.of(node)))); } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestConnectorExpressionTranslator.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestConnectorExpressionTranslator.java index 8d65ed12d75a..d13f4e928259 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestConnectorExpressionTranslator.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestConnectorExpressionTranslator.java @@ -25,9 +25,11 @@ import io.trino.spi.expression.StandardFunctions; import io.trino.spi.expression.Variable; import io.trino.spi.type.Type; +import io.trino.spi.type.VarcharType; import io.trino.sql.tree.ArithmeticBinaryExpression; import io.trino.sql.tree.ArithmeticUnaryExpression; import io.trino.sql.tree.BetweenPredicate; +import io.trino.sql.tree.Cast; import io.trino.sql.tree.ComparisonExpression; import io.trino.sql.tree.DoubleLiteral; import io.trino.sql.tree.Expression; @@ -55,6 +57,7 @@ import static com.google.common.collect.ImmutableMap.toImmutableMap; import static io.airlift.slice.Slices.utf8Slice; import static io.trino.spi.expression.StandardFunctions.AND_FUNCTION_NAME; +import static io.trino.spi.expression.StandardFunctions.CAST_FUNCTION_NAME; import static io.trino.spi.expression.StandardFunctions.GREATER_THAN_OR_EQUAL_OPERATOR_FUNCTION_NAME; import static io.trino.spi.expression.StandardFunctions.IS_NULL_FUNCTION_NAME; import static io.trino.spi.expression.StandardFunctions.LESS_THAN_OR_EQUAL_OPERATOR_FUNCTION_NAME; @@ -73,6 +76,7 @@ import static io.trino.spi.type.SmallintType.SMALLINT; import static io.trino.spi.type.TinyintType.TINYINT; import static io.trino.spi.type.VarcharType.createVarcharType; +import static io.trino.sql.analyzer.TypeSignatureTranslator.toSqlType; import static io.trino.sql.planner.ConnectorExpressionTranslator.translate; import static io.trino.sql.planner.TestingPlannerContext.PLANNER_CONTEXT; import static io.trino.sql.planner.TypeAnalyzer.createTestingTypeAnalyzer; @@ -86,7 +90,7 @@ public class TestConnectorExpressionTranslator private static final Session TEST_SESSION = TestingSession.testSessionBuilder().build(); private static final TypeAnalyzer TYPE_ANALYZER = createTestingTypeAnalyzer(PLANNER_CONTEXT); private static final Type ROW_TYPE = rowType(field("int_symbol_1", INTEGER), field("varchar_symbol_1", createVarcharType(5))); - private static final Type VARCHAR_TYPE = createVarcharType(25); + private static final VarcharType VARCHAR_TYPE = createVarcharType(25); private static final LiteralEncoder LITERAL_ENCODER = new LiteralEncoder(PLANNER_CONTEXT); private static final Map symbols = ImmutableMap.builder() @@ -330,6 +334,37 @@ public void testTranslateNullIf() new Variable("varchar_symbol_1", VARCHAR_TYPE)))); } + @Test + public void testTranslateCast() + { + assertTranslationRoundTrips( + new Cast(new SymbolReference("varchar_symbol_1"), toSqlType(VARCHAR_TYPE)), + new Call( + VARCHAR_TYPE, + CAST_FUNCTION_NAME, + List.of(new Variable("varchar_symbol_1", VARCHAR_TYPE)))); + + // type-only + VarcharType longerVarchar = createVarcharType(VARCHAR_TYPE.getBoundedLength() + 1); + assertTranslationToConnectorExpression( + TEST_SESSION, + new Cast(new SymbolReference("varchar_symbol_1"), toSqlType(longerVarchar), false, true), + new Call( + longerVarchar, + CAST_FUNCTION_NAME, + List.of(new Variable("varchar_symbol_1", VARCHAR_TYPE)))); + + // TRY_CAST is not translated + assertTranslationToConnectorExpression( + TEST_SESSION, + new Cast( + new SymbolReference("varchar_symbol_1"), + toSqlType(BIGINT), + true, + true), + Optional.empty()); + } + @Test public void testTranslateResolvedFunction() { diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestTableScanRedirectionWithPushdown.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestTableScanRedirectionWithPushdown.java index 67f582ea50df..fa64db21371a 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestTableScanRedirectionWithPushdown.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestTableScanRedirectionWithPushdown.java @@ -32,6 +32,7 @@ import io.trino.spi.connector.ProjectionApplicationResult; import io.trino.spi.connector.SchemaTableName; import io.trino.spi.connector.TableScanRedirectApplicationResult; +import io.trino.spi.expression.Call; import io.trino.spi.expression.ConnectorExpression; import io.trino.spi.expression.FieldDereference; import io.trino.spi.expression.Variable; @@ -53,6 +54,7 @@ import static io.trino.connector.MockConnectorFactory.ApplyFilter; import static io.trino.connector.MockConnectorFactory.ApplyProjection; import static io.trino.connector.MockConnectorFactory.ApplyTableScanRedirect; +import static io.trino.spi.expression.StandardFunctions.CAST_FUNCTION_NAME; import static io.trino.spi.predicate.Domain.singleValue; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.IntegerType.INTEGER; @@ -366,9 +368,12 @@ private Optional> mockApplyPro for (ConnectorExpression projection : projections) { String newVariableName; + ConnectorExpression newVariable; ColumnHandle newColumnHandle; + Type type = projection.getType(); if (projection instanceof Variable) { newVariableName = ((Variable) projection).getName(); + newVariable = new Variable(newVariableName, type); newColumnHandle = assignments.get(newVariableName); } else if (projection instanceof FieldDereference) { @@ -378,16 +383,27 @@ else if (projection instanceof FieldDereference) { } String dereferenceTargetName = ((Variable) dereference.getTarget()).getName(); newVariableName = ((MockConnectorColumnHandle) assignments.get(dereferenceTargetName)).getName() + "#" + dereference.getField(); - newColumnHandle = new MockConnectorColumnHandle(newVariableName, projection.getType()); + newVariable = new Variable(newVariableName, type); + newColumnHandle = new MockConnectorColumnHandle(newVariableName, type); + } + else if (projection instanceof Call) { + Call call = (Call) projection; + if (!(CAST_FUNCTION_NAME.equals(call.getFunctionName()) && call.getArguments().size() == 1)) { + throw new UnsupportedOperationException(); + } + // Avoid CAST pushdown into the connector + newVariableName = ((Variable) call.getArguments().get(0)).getName(); + newVariable = projection; + newColumnHandle = assignments.get(newVariableName); + type = call.getArguments().get(0).getType(); } else { throw new UnsupportedOperationException(); } - Variable newVariable = new Variable(newVariableName, projection.getType()); newColumnsBuilder.add(newColumnHandle); outputExpressions.add(newVariable); - outputAssignments.add(new Assignment(newVariableName, newColumnHandle, projection.getType())); + outputAssignments.add(new Assignment(newVariableName, newColumnHandle, type)); } List newColumns = newColumnsBuilder.build(); diff --git a/core/trino-spi/src/main/java/io/trino/spi/expression/StandardFunctions.java b/core/trino-spi/src/main/java/io/trino/spi/expression/StandardFunctions.java index bf5679fb442a..cd7694efb7c1 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/expression/StandardFunctions.java +++ b/core/trino-spi/src/main/java/io/trino/spi/expression/StandardFunctions.java @@ -38,6 +38,11 @@ private StandardFunctions() {} */ public static final FunctionName NULLIF_FUNCTION_NAME = new FunctionName("$nullif"); + /** + * $cast function result type is determined by the {@link Call#getType()} + */ + public static final FunctionName CAST_FUNCTION_NAME = new FunctionName("$cast"); + public static final FunctionName EQUAL_OPERATOR_FUNCTION_NAME = new FunctionName("$equal"); public static final FunctionName NOT_EQUAL_OPERATOR_FUNCTION_NAME = new FunctionName("$not_equal"); public static final FunctionName LESS_THAN_OPERATOR_FUNCTION_NAME = new FunctionName("$less_than");