From b5c69fafda2e99316c2a501e786d08978b6bad9b Mon Sep 17 00:00:00 2001 From: Vikash Kumar Date: Wed, 24 Jul 2024 15:08:42 +0530 Subject: [PATCH] Support integral cast projection pushdown in redshift --- .../plugin/jdbc/BaseJdbcCastPushdownTest.java | 43 +- plugin/trino-redshift/pom.xml | 2 + .../trino/plugin/redshift/RedshiftClient.java | 15 + .../io/trino/plugin/redshift/RewriteCast.java | 90 ++++ .../redshift/TestRedshiftCastPushdown.java | 494 ++++++++++++++++++ .../redshift/TestRedshiftConnectorTest.java | 84 +++ 6 files changed, 723 insertions(+), 5 deletions(-) create mode 100644 plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RewriteCast.java create mode 100644 plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/TestRedshiftCastPushdown.java diff --git a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/BaseJdbcCastPushdownTest.java b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/BaseJdbcCastPushdownTest.java index d26ff33a0135..5a5d84550cfb 100644 --- a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/BaseJdbcCastPushdownTest.java +++ b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/BaseJdbcCastPushdownTest.java @@ -13,12 +13,15 @@ */ package io.trino.plugin.jdbc; +import io.trino.Session; import io.trino.sql.planner.plan.ProjectNode; import io.trino.testing.AbstractTestQueryFramework; import io.trino.testing.sql.SqlExecutor; +import org.intellij.lang.annotations.Language; import org.junit.jupiter.api.Test; import java.util.List; +import java.util.Optional; import static java.util.Objects.requireNonNull; import static org.assertj.core.api.Assertions.assertThat; @@ -69,10 +72,29 @@ public void testJoinPushdownWithCast() @Test public void testInvalidCast() { - for (InvalidCastTestCase testCase : invalidCast()) { - assertThat(query("SELECT CAST(%s AS %s) FROM %s".formatted(testCase.sourceColumn(), testCase.castType(), leftTable()))) - .failure() - .hasMessageMatching(testCase.errorMessage()); + assertInvalidCast(leftTable(), invalidCast()); + } + + protected void assertInvalidCast(String tableName, List invalidCastTestCases) + { + Session withoutPushdown = Session.builder(getSession()) + .setSystemProperty("allow_pushdown_into_connectors", "false") + .build(); + + for (InvalidCastTestCase testCase : invalidCastTestCases) { + if (testCase.pushdownErrorMessage().isPresent()) { + assertThat(query("SELECT CAST(%s AS %s) FROM %s".formatted(testCase.sourceColumn(), testCase.castType(), tableName))) + .failure() + .hasMessageMatching(testCase.pushdownErrorMessage().get()); + assertThat(query(withoutPushdown, "SELECT CAST(%s AS %s) FROM %s".formatted(testCase.sourceColumn(), testCase.castType(), tableName))) + .failure() + .hasMessageMatching(testCase.errorMessage()); + } + else { + assertThat(query("SELECT CAST(%s AS %s) FROM %s".formatted(testCase.sourceColumn(), testCase.castType(), tableName))) + .failure() + .hasMessageMatching(testCase.errorMessage()); + } } } @@ -86,18 +108,29 @@ public record CastTestCase(String sourceColumn, String castType, String targetCo } } - public record InvalidCastTestCase(String sourceColumn, String castType, String errorMessage) + public record InvalidCastTestCase(String sourceColumn, String castType, String errorMessage, Optional pushdownErrorMessage) { public InvalidCastTestCase(String sourceColumn, String castType) { this(sourceColumn, castType, "(.*)Cannot cast (.*) to (.*)"); } + public InvalidCastTestCase(String sourceColumn, String castType, String errorMessage) + { + this(sourceColumn, castType, errorMessage, Optional.empty()); + } + + public InvalidCastTestCase(String sourceColumn, String castType, String errorMessage, @Language("RegExp") String pushdownErrorMessage) + { + this(sourceColumn, castType, errorMessage, Optional.of(pushdownErrorMessage)); + } + public InvalidCastTestCase { requireNonNull(sourceColumn, "sourceColumn is null"); requireNonNull(castType, "castType is null"); requireNonNull(errorMessage, "errorMessage is null"); + requireNonNull(pushdownErrorMessage, "pushdownErrorMessage is null"); } } } diff --git a/plugin/trino-redshift/pom.xml b/plugin/trino-redshift/pom.xml index 9fc74f7c7c52..26270fedc206 100644 --- a/plugin/trino-redshift/pom.xml +++ b/plugin/trino-redshift/pom.xml @@ -234,6 +234,7 @@ **/TestRedshiftAutomaticJoinPushdown.java + **/TestRedshiftCastPushdown.java **/TestRedshiftConnectorTest.java **/TestRedshiftConnectorSmokeTest.java **/TestRedshiftTableStatisticsReader.java @@ -262,6 +263,7 @@ + **/TestRedshiftCastPushdown.java **/TestRedshiftConnectorSmokeTest.java diff --git a/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftClient.java b/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftClient.java index 64701f2082c4..d8fc2991eea8 100644 --- a/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftClient.java +++ b/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftClient.java @@ -24,6 +24,8 @@ import io.trino.plugin.base.aggregation.AggregateFunctionRule; import io.trino.plugin.base.expression.ConnectorExpressionRewriter; import io.trino.plugin.base.mapping.IdentifierMapping; +import io.trino.plugin.base.projection.ProjectFunctionRewriter; +import io.trino.plugin.base.projection.ProjectFunctionRule; import io.trino.plugin.jdbc.BaseJdbcClient; import io.trino.plugin.jdbc.BaseJdbcConfig; import io.trino.plugin.jdbc.ColumnMapping; @@ -222,6 +224,7 @@ public class RedshiftClient .toFormatter(); private static final OffsetDateTime REDSHIFT_MIN_SUPPORTED_TIMESTAMP_TZ = OffsetDateTime.of(-4712, 1, 1, 0, 0, 0, 0, ZoneOffset.UTC); + private final ProjectFunctionRewriter projectFunctionRewriter; private final AggregateFunctionRewriter aggregateFunctionRewriter; private final boolean statisticsEnabled; private final RedshiftTableStatisticsReader statisticsReader; @@ -248,6 +251,12 @@ public RedshiftClient( .map("$greater_than_or_equal(left, right)").to("left >= right") .build(); + this.projectFunctionRewriter = new ProjectFunctionRewriter<>( + connectorExpressionRewriter, + ImmutableSet.>builder() + .add(new RewriteCast((session, type) -> toWriteMapping(session, type).getDataType())) + .build()); + JdbcTypeHandle bigintTypeHandle = new JdbcTypeHandle(Types.BIGINT, Optional.of("bigint"), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()); aggregateFunctionRewriter = new AggregateFunctionRewriter<>( @@ -359,6 +368,12 @@ public Optional implementAggregation(ConnectorSession session, A return aggregateFunctionRewriter.rewrite(session, aggregate, assignments); } + @Override + public Optional convertProjection(ConnectorSession session, JdbcTableHandle handle, ConnectorExpression expression, Map assignments) + { + return projectFunctionRewriter.rewrite(session, handle, expression, assignments); + } + @Override public Optional convertPredicate(ConnectorSession session, ConnectorExpression expression, Map assignments) { diff --git a/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RewriteCast.java b/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RewriteCast.java new file mode 100644 index 000000000000..1694b5a4c8cd --- /dev/null +++ b/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RewriteCast.java @@ -0,0 +1,90 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.redshift; + +import com.google.common.collect.ImmutableList; +import io.trino.plugin.jdbc.JdbcTypeHandle; +import io.trino.plugin.jdbc.expression.AbstractRewriteCast; +import io.trino.spi.connector.ConnectorSession; +import io.trino.spi.type.BigintType; +import io.trino.spi.type.DecimalType; +import io.trino.spi.type.IntegerType; +import io.trino.spi.type.SmallintType; +import io.trino.spi.type.Type; + +import java.util.List; +import java.util.Optional; +import java.util.function.BiFunction; + +import static java.sql.Types.BIGINT; +import static java.sql.Types.BIT; +import static java.sql.Types.INTEGER; +import static java.sql.Types.NUMERIC; +import static java.sql.Types.SMALLINT; + +public class RewriteCast + extends AbstractRewriteCast +{ + private static final List SUPPORTED_SOURCE_TYPE_FOR_INTEGRAL_CAST = ImmutableList.of(BIT, SMALLINT, INTEGER, BIGINT, NUMERIC); + + public RewriteCast(BiFunction jdbcTypeProvider) + { + super(jdbcTypeProvider); + } + + @Override + protected Optional toJdbcTypeHandle(JdbcTypeHandle sourceType, Type targetType) + { + if (!pushdownSupported(sourceType, targetType)) { + return Optional.empty(); + } + + return switch (targetType) { + case SmallintType smallintType -> + Optional.of(new JdbcTypeHandle(SMALLINT, Optional.of(smallintType.getBaseName()), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty())); + case IntegerType integerType -> + Optional.of(new JdbcTypeHandle(INTEGER, Optional.of(integerType.getBaseName()), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty())); + case BigintType bigintType -> + Optional.of(new JdbcTypeHandle(BIGINT, Optional.of(bigintType.getBaseName()), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty())); + default -> Optional.empty(); + }; + } + + private boolean pushdownSupported(JdbcTypeHandle sourceType, Type targetType) + { + return switch (targetType) { + case SmallintType _, IntegerType _, BigintType _ -> + SUPPORTED_SOURCE_TYPE_FOR_INTEGRAL_CAST.contains(sourceType.jdbcType()); + default -> false; + }; + } + + @Override + protected String buildCast(Type sourceType, Type targetType, String expression, String castType) + { + if (sourceType instanceof DecimalType && isIntegralType(targetType)) { + // Trino rounds up to nearest integral value, whereas Redshift does not. + // So using ROUND() to make pushdown same as the trino behavior + return "CAST(ROUND(%s) AS %s)".formatted(expression, castType); + } + return "CAST(%s AS %s)".formatted(expression, castType); + } + + private boolean isIntegralType(Type type) + { + return type instanceof SmallintType + || type instanceof IntegerType + || type instanceof BigintType; + } +} diff --git a/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/TestRedshiftCastPushdown.java b/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/TestRedshiftCastPushdown.java new file mode 100644 index 000000000000..eb4ce8bcd678 --- /dev/null +++ b/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/TestRedshiftCastPushdown.java @@ -0,0 +1,494 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.redshift; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.trino.Session; +import io.trino.plugin.jdbc.BaseJdbcCastPushdownTest; +import io.trino.plugin.jdbc.CastDataTypeTestTable; +import io.trino.sql.planner.plan.ProjectNode; +import io.trino.testing.QueryRunner; +import io.trino.testing.sql.SqlExecutor; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; + +import java.util.List; + +import static io.trino.plugin.redshift.TestingRedshiftServer.TEST_SCHEMA; +import static java.util.Arrays.asList; +import static org.assertj.core.api.Assertions.assertThat; + +public class TestRedshiftCastPushdown + extends BaseJdbcCastPushdownTest +{ + private CastDataTypeTestTable left; + private CastDataTypeTestTable right; + + @Override + protected QueryRunner createQueryRunner() + throws Exception + { + return RedshiftQueryRunner.builder() + .setConnectorProperties(ImmutableMap.builder() + .put("unsupported-type-handling", "CONVERT_TO_VARCHAR") + .put("join-pushdown.enabled", "true") + .put("join-pushdown.strategy", "EAGER") + .buildOrThrow()) + .build(); + } + + @Override + protected SqlExecutor onRemoteDatabase() + { + return TestingRedshiftServer::executeInRedshift; + } + + @BeforeAll + public void setupTable() + { + left = closeAfterClass(CastDataTypeTestTable.create(3) + .addColumn("id", "int", asList(11, 12, 13)) + .addColumn("c_boolean", "boolean", asList(true, false, null)) + .addColumn("c_smallint", "smallint", asList(1, 2, null)) + .addColumn("c_int2", "int2", asList(1, 2, null)) + .addColumn("c_integer", "integer", asList(1, 2, null)) + .addColumn("c_int", "int", asList(1, 2, null)) + .addColumn("c_int4", "int4", asList(1, 2, null)) + .addColumn("c_bigint", "bigint", asList(1, 2, null)) + .addColumn("c_int8", "int8", asList(1, 2, null)) + .addColumn("c_real", "real", asList(1.23, 2.67, null)) + .addColumn("c_float4", "float4", asList(1.23, 2.67, null)) + .addColumn("c_double_precision", "double precision", asList(1.23, 2.67, null)) + .addColumn("c_float", "float", asList(1.23, 2.67, null)) + .addColumn("c_float8", "float8", asList(1.23, 2.67, null)) + .addColumn("c_decimal_10_2", "decimal(10, 2)", asList(1.23, 2.67, null)) + .addColumn("c_decimal_19_2", "decimal(19, 2)", asList(1.23, 2.67, null)) // Equal to REDSHIFT_DECIMAL_CUTOFF_PRECISION + .addColumn("c_decimal_30_2", "decimal(30, 2)", asList(1.23, 2.67, null)) + .addColumn("c_numeric_10_2", "numeric(10, 2)", asList(1.23, 2.67, null)) + .addColumn("c_numeric_19_2", "numeric(19, 2)", asList(1.23, 2.67, null)) // Equal to REDSHIFT_DECIMAL_CUTOFF_PRECISION + .addColumn("c_numeric_30_2", "numeric(30, 2)", asList(1.23, 2.67, null)) + .addColumn("c_char_10", "char(10)", asList("'India'", "'Poland'", null)) + .addColumn("c_char_50", "char(50)", asList("'India'", "'Poland'", null)) + .addColumn("c_char_4096", "char(4096)", asList("'India'", "'Poland'", null)) // Equal to REDSHIFT_MAX_CHAR + .addColumn("c_nchar_10", "nchar(10)", asList("'India'", "'Poland'", null)) + .addColumn("c_nchar_50", "nchar(50)", asList("'India'", "'Poland'", null)) + .addColumn("c_nchar_4096", "nchar(4096)", asList("'India'", "'Poland'", null)) // Equal to REDSHIFT_MAX_CHAR + .addColumn("c_bpchar", "bpchar", asList("'India'", "'Poland'", null)) + .addColumn("c_varchar_10", "varchar(10)", asList("'India'", "'Poland'", null)) + .addColumn("c_varchar_50", "varchar(50)", asList("'India'", "'Poland'", null)) + .addColumn("c_varchar_65535", "varchar(65535)", asList("'India'", "'Poland'", null)) // Equal to REDSHIFT_MAX_VARCHAR + .addColumn("c_nvarchar_10", "nvarchar(10)", asList("'India'", "'Poland'", null)) + .addColumn("c_nvarchar_50", "nvarchar(50)", asList("'India'", "'Poland'", null)) + .addColumn("c_nvarchar_65535", "nvarchar(65535)", asList("'India'", "'Poland'", null)) // Greater than REDSHIFT_MAX_VARCHAR + .addColumn("c_text", "text", asList("'India'", "'Poland'", null)) + .addColumn("c_varbinary", "varbinary", asList("'\\x66696E6465706920726F636B7321'", "'\\x000102f0feee'", null)) + .addColumn("c_date", "date", asList("DATE '2024-09-08'", "DATE '2019-08-15'", null)) + .addColumn("c_time", "time", asList("TIME '00:13:42.000000'", "TIME '10:01:17.100000'", null)) + .addColumn("c_timestamp", "timestamp", asList("TIMESTAMP '2024-09-08 01:02:03.666'", "TIMESTAMP '2019-08-15 09:08:07.333'", null)) + .addColumn("c_timestamptz", "timestamptz", asList("TIMESTAMP '2024-09-08 01:02:03.666+05:30'", "TIMESTAMP '2019-08-15 09:08:07.333+05:30'", null)) + + .addColumn("c_nan_real", "real", asList("'Nan'", "'-Nan'", null)) + .addColumn("c_nan_double", "double precision", asList("'Nan'", "'-Nan'", null)) + .addColumn("c_infinity_real", "real", asList("'Infinity'", "'-Infinity'", null)) + .addColumn("c_infinity_double", "double precision", asList("'Infinity'", "'-Infinity'", null)) + .addColumn("c_decimal_negative", "decimal(19, 2)", asList(-1.23, -2.67, null)) + .addColumn("c_varchar_numeric", "varchar(50)", asList("'123'", "'124'", null)) + .addColumn("c_char_numeric", "char(50)", asList("'123'", "'124'", null)) + .addColumn("c_bpchar_numeric", "bpchar", asList("'123'", "'124'", null)) + .addColumn("c_text_numeric", "text", asList("'123'", "'124'", null)) + .addColumn("c_nvarchar_numeric", "nvarchar(50)", asList("'123'", "'124'", null)) + .addColumn("c_varchar_numeric_sign", "varchar(50)", asList("'+123'", "'-124'", null)) + .addColumn("c_varchar_decimal", "varchar(50)", asList("'1.23'", "'2.34'", null)) + .addColumn("c_varchar_decimal_sign", "varchar(50)", asList("'+1.23'", "'-2.34'", null)) + .addColumn("c_varchar_alpha_numeric", "varchar(50)", asList("'H311o'", "'123Hey'", null)) + .addColumn("c_varchar_date", "varchar(50)", asList("'2024-09-08'", "'2019-08-15'", null)) + .addColumn("c_varchar_timestamp", "varchar(50)", asList("'2024-09-08 01:02:03.666'", "'2019-08-15 09:08:07.333'", null)) + .addColumn("c_varchar_timestamptz", "varchar(50)", asList("'2024-09-08 01:02:03.666+05:30'", "'2019-08-15 09:08:07.333+05:30'", null)) + + // unsupported in trino + .addColumn("c_timetz", "timetz", asList("TIME '00:13:42.000000+05:30'", "TIME '10:01:17.100000+05:30'", null)) + .addColumn("c_super", "super", asList(1, 2, null)) + .execute(onRemoteDatabase(), TEST_SCHEMA + "." + "left_table_")); + + // 2nd row value is different in right than left + right = closeAfterClass(CastDataTypeTestTable.create(3) + .addColumn("id", "int", asList(21, 22, 23)) + .addColumn("c_boolean", "boolean", asList(true, true, null)) + .addColumn("c_smallint", "smallint", asList(1, 22, null)) + .addColumn("c_int2", "int2", asList(1, 22, null)) + .addColumn("c_integer", "integer", asList(1, 22, null)) + .addColumn("c_int", "int", asList(1, 22, null)) + .addColumn("c_int4", "int4", asList(1, 22, null)) + .addColumn("c_bigint", "bigint", asList(1, 22, null)) + .addColumn("c_int8", "int8", asList(1, 22, null)) + .addColumn("c_real", "real", asList(1.23, 22.67, null)) + .addColumn("c_float4", "float4", asList(1.23, 22.67, null)) + .addColumn("c_double_precision", "double precision", asList(1.23, 22.67, null)) + .addColumn("c_float", "float", asList(1.23, 22.67, null)) + .addColumn("c_float8", "float8", asList(1.23, 22.67, null)) + .addColumn("c_decimal_10_2", "decimal(10, 2)", asList(1.23, 22.67, null)) + .addColumn("c_decimal_19_2", "decimal(19, 2)", asList(1.23, 22.67, null)) // Equal to REDSHIFT_DECIMAL_CUTOFF_PRECISION + .addColumn("c_decimal_30_2", "decimal(30, 2)", asList(1.23, 22.67, null)) + .addColumn("c_numeric_10_2", "numeric(10, 2)", asList(1.23, 22.67, null)) + .addColumn("c_numeric_19_2", "numeric(19, 2)", asList(1.23, 22.67, null)) // Equal to REDSHIFT_DECIMAL_CUTOFF_PRECISION + .addColumn("c_numeric_30_2", "numeric(30, 2)", asList(1.23, 22.67, null)) + .addColumn("c_char_10", "char(10)", asList("'India'", "'France'", null)) + .addColumn("c_char_50", "char(50)", asList("'India'", "'France'", null)) + .addColumn("c_char_4096", "char(4096)", asList("'India'", "'France'", null)) // Equal to REDSHIFT_MAX_CHAR + .addColumn("c_nchar_10", "nchar(10)", asList("'India'", "'France'", null)) + .addColumn("c_nchar_50", "nchar(50)", asList("'India'", "'France'", null)) + .addColumn("c_nchar_4096", "nchar(4096)", asList("'India'", "'France'", null)) // Equal to REDSHIFT_MAX_CHAR + .addColumn("c_bpchar", "bpchar", asList("'India'", "'France'", null)) + .addColumn("c_varchar_10", "varchar(10)", asList("'India'", "'France'", null)) + .addColumn("c_varchar_50", "varchar(50)", asList("'India'", "'France'", null)) + .addColumn("c_varchar_65535", "varchar(65535)", asList("'India'", "'France'", null)) // Equal to REDSHIFT_MAX_VARCHAR + .addColumn("c_nvarchar_10", "nvarchar(10)", asList("'India'", "'France'", null)) + .addColumn("c_nvarchar_50", "nvarchar(50)", asList("'India'", "'France'", null)) + .addColumn("c_nvarchar_65535", "nvarchar(65535)", asList("'India'", "'France'", null)) // Equal to REDSHIFT_MAX_VARCHAR + .addColumn("c_text", "text", asList("'India'", "'France'", null)) + .addColumn("c_varbinary", "varbinary", asList("'\\x66696E6465706920726F636B7321'", "'\\x4672616E6365'", null)) + .addColumn("c_date", "date", asList("DATE '2024-09-08'", "DATE '2020-08-15'", null)) + .addColumn("c_time", "time", asList("TIME '00:13:42.000000'", "TIME '11:01:17.100000'", null)) + .addColumn("c_timestamp", "timestamp", asList("TIMESTAMP '2024-09-08 01:02:03.666'", "TIMESTAMP '2020-08-15 09:08:07.333'", null)) + .addColumn("c_timestamptz", "timestamptz", asList("TIMESTAMP '2024-09-08 01:02:03.666+05:30'", "TIMESTAMP '2020-08-15 09:08:07.333+05:30'", null)) + + .addColumn("c_nan_real", "real", asList("'Nan'", "'-Nan'", null)) + .addColumn("c_nan_double", "double precision", asList("'Nan'", "'-Nan'", null)) + .addColumn("c_infinity_real", "real", asList("'Infinity'", "'-Infinity'", null)) + .addColumn("c_infinity_double", "double precision", asList("'Infinity'", "'-Infinity'", null)) + .addColumn("c_decimal_negative", "decimal(19, 2)", asList(-1.23, -22.67, null)) + .addColumn("c_varchar_numeric", "varchar(50)", asList("'123'", "'228'", null)) + .addColumn("c_char_numeric", "char(50)", asList("'123'", "'125'", null)) + .addColumn("c_bpchar_numeric", "bpchar", asList("'123'", "'125'", null)) + .addColumn("c_text_numeric", "text", asList("'123'", "'125'", null)) + .addColumn("c_nvarchar_numeric", "nvarchar(50)", asList("'123'", "'125'", null)) + .addColumn("c_varchar_numeric_sign", "varchar(50)", asList("'+123'", "'-125'", null)) + .addColumn("c_varchar_decimal", "varchar(50)", asList("'1.23'", "'22.34'", null)) + .addColumn("c_varchar_decimal_sign", "varchar(50)", asList("'+1.23'", "'-22.34'", null)) + .addColumn("c_varchar_alpha_numeric", "varchar(50)", asList("'H311o'", "'123Bye'", null)) + .addColumn("c_varchar_date", "varchar(50)", asList("'2024-09-08'", "'2020-08-15'", null)) + .addColumn("c_varchar_timestamp", "varchar(50)", asList("'2024-09-08 01:02:03.666'", "'2020-08-15 09:08:07.333'", null)) + .addColumn("c_varchar_timestamptz", "varchar(50)", asList("'2024-09-08 01:02:03.666+05:30'", "'2020-08-15 09:08:07.333+05:30'", null)) + + // unsupported in trino + .addColumn("c_timetz", "timetz", asList("TIME '00:13:42.000000+05:30'", "TIME '11:01:17.100000+05:30'", null)) + .addColumn("c_super", "super", asList(1, 22, null)) + .execute(onRemoteDatabase(), TEST_SCHEMA + "." + "right_table_")); + } + + @Override + protected String leftTable() + { + return left.getName(); + } + + @Override + protected String rightTable() + { + return right.getName(); + } + + @Test + public void testJoinPushdownWithNestedCast() + { + CastTestCase testCase = new CastTestCase("c_decimal_10_2", "bigint", "c_bigint"); + assertThat(query("SELECT l.id FROM %s l LEFT JOIN %s r ON CAST(CAST(l.%s AS %s) AS integer) = r.%s".formatted(leftTable(), rightTable(), testCase.sourceColumn(), testCase.castType(), testCase.targetColumn()))) + .isFullyPushedDown(); + } + + @Test + public void testAllJoinPushdownWithCast() + { + CastTestCase testCase = new CastTestCase("c_int", "bigint", "c_bigint"); + assertThat(query("SELECT l.id FROM %s l LEFT JOIN %s r ON CAST(l.%s AS %s) = r.%s".formatted(leftTable(), rightTable(), testCase.sourceColumn(), testCase.castType(), testCase.targetColumn()))) + .isFullyPushedDown(); + assertThat(query("SELECT l.id FROM %s l RIGHT JOIN %s r ON CAST(l.%s AS %s) = r.%s".formatted(leftTable(), rightTable(), testCase.sourceColumn(), testCase.castType(), testCase.targetColumn()))) + .isFullyPushedDown(); + assertThat(query("SELECT l.id FROM %s l INNER JOIN %s r ON CAST(l.%s AS %s) = r.%s".formatted(leftTable(), rightTable(), testCase.sourceColumn(), testCase.castType(), testCase.targetColumn()))) + .isFullyPushedDown(); + // Full Join pushdown is not supported + assertThat(query("SELECT l.id FROM %s l FULL JOIN %s r ON CAST(l.%s AS %s) = r.%s".formatted(leftTable(), rightTable(), testCase.sourceColumn(), testCase.castType(), testCase.targetColumn()))) + .joinIsNotFullyPushedDown(); + + testCase = new CastTestCase("c_bigint", "integer", "c_int"); + assertThat(query("SELECT l.id FROM %s l LEFT JOIN %s r ON l.%s = CAST(r.%s AS %s)".formatted(leftTable(), rightTable(), testCase.sourceColumn(), testCase.targetColumn(), testCase.castType()))) + .isFullyPushedDown(); + assertThat(query("SELECT l.id FROM %s l RIGHT JOIN %s r ON l.%s = CAST(r.%s AS %s)".formatted(leftTable(), rightTable(), testCase.sourceColumn(), testCase.targetColumn(), testCase.castType()))) + .isFullyPushedDown(); + assertThat(query("SELECT l.id FROM %s l INNER JOIN %s r ON l.%s = CAST(r.%s AS %s)".formatted(leftTable(), rightTable(), testCase.sourceColumn(), testCase.targetColumn(), testCase.castType()))) + .isFullyPushedDown(); + // Full Join pushdown is not supported + assertThat(query("SELECT l.id FROM %s l FULL JOIN %s r ON CAST(l.%s AS %s) = r.%s".formatted(leftTable(), rightTable(), testCase.sourceColumn(), testCase.castType(), testCase.targetColumn()))) + .joinIsNotFullyPushedDown(); + + testCase = new CastTestCase("c_bigint", "smallint", "c_int"); + assertThat(query("SELECT l.id FROM %s l LEFT JOIN %s r ON CAST(l.%3$s AS %4$s) = CAST(r.%5$s AS %4$s)".formatted(leftTable(), rightTable(), testCase.sourceColumn(), testCase.castType(), testCase.targetColumn()))) + .isFullyPushedDown(); + assertThat(query("SELECT l.id FROM %s l RIGHT JOIN %s r ON CAST(l.%3$s AS %4$s) = CAST(r.%5$s AS %4$s)".formatted(leftTable(), rightTable(), testCase.sourceColumn(), testCase.castType(), testCase.targetColumn()))) + .isFullyPushedDown(); + assertThat(query("SELECT l.id FROM %s l INNER JOIN %s r ON CAST(l.%3$s AS %4$s) = CAST(r.%5$s AS %4$s)".formatted(leftTable(), rightTable(), testCase.sourceColumn(), testCase.castType(), testCase.targetColumn()))) + .isFullyPushedDown(); + // Full Join pushdown is not supported + assertThat(query("SELECT l.id FROM %s l FULL JOIN %s r ON CAST(l.%s AS %s) = r.%s".formatted(leftTable(), rightTable(), testCase.sourceColumn(), testCase.castType(), testCase.targetColumn()))) + .joinIsNotFullyPushedDown(); + } + + @Test + public void testCastPushdownDisabled() + { + Session sessionWithoutPushdown = Session.builder(getSession()) + .setCatalogSessionProperty(getSession().getCatalog().orElseThrow(), "complex_expression_pushdown", "false") + .build(); + assertThat(query(sessionWithoutPushdown, "SELECT CAST (c_int AS bigint) FROM %s".formatted(leftTable()))) + .isNotFullyPushedDown(ProjectNode.class); + } + + @Test + public void testCastPushdownOutOfRangeValue() + { + CastDataTypeTestTable table = closeAfterClass(CastDataTypeTestTable.create(1) + .addColumn("id", "int", List.of(1)) + .addColumn("c_smallint_1", "smallint", List.of("-129")) + .addColumn("c_smallint_2", "smallint", List.of("128")) + .addColumn("c_int_1", "integer", List.of("-65537")) + .addColumn("c_int_2", "integer", List.of("65536")) + .addColumn("c_bigint_1", "bigint", List.of("-2147483649")) + .addColumn("c_bigint_2", "bigint", List.of("2147483648")) + + .addColumn("c_decimal_1", "decimal(25, 2)", List.of("-129.49")) + .addColumn("c_decimal_2", "decimal(25, 2)", List.of("-128.94")) + .addColumn("c_decimal_3", "decimal(25, 2)", List.of("127.94")) + .addColumn("c_decimal_4", "decimal(25, 2)", List.of("128.49")) + .addColumn("c_decimal_5", "decimal(25, 2)", List.of("-65537.49")) + .addColumn("c_decimal_6", "decimal(25, 2)", List.of("-65536.94")) + .addColumn("c_decimal_7", "decimal(25, 2)", List.of("65535.94")) + .addColumn("c_decimal_8", "decimal(25, 2)", List.of("65536.49")) + .addColumn("c_decimal_9", "decimal(25, 2)", List.of("-2147483649.49")) + .addColumn("c_decimal_10", "decimal(25, 2)", List.of("-2147483648.94")) + .addColumn("c_decimal_11", "decimal(25, 2)", List.of("2147483647.94")) + .addColumn("c_decimal_12", "decimal(25, 2)", List.of("2147483648.49")) + .addColumn("c_decimal_13", "decimal(25, 2)", List.of("-9223372036854775809.49")) + .addColumn("c_decimal_14", "decimal(25, 2)", List.of("-9223372036854775808.94")) + .addColumn("c_decimal_15", "decimal(25, 2)", List.of("9223372036854775807.94")) + .addColumn("c_decimal_16", "decimal(25, 2)", List.of("9223372036854775808.49")) + + .addColumn("c_infinity_real_1", "real", List.of("'Infinity'")) + .addColumn("c_infinity_real_2", "real", List.of("'-Infinity'")) + .addColumn("c_infinity_double_1", "double precision", List.of("'Infinity'")) + .addColumn("c_infinity_double_2", "double precision", List.of("'-Infinity'")) + .execute(onRemoteDatabase(), TEST_SCHEMA + "." + "test_decimal_overflow_")); + + assertInvalidCast( + table.getName(), + ImmutableList.of( + // Not pushdown for tinyint type + new InvalidCastTestCase("c_smallint_1", "tinyint", "Out of range for tinyint: -129"), + new InvalidCastTestCase("c_smallint_2", "tinyint", "Out of range for tinyint: 128"), + new InvalidCastTestCase("c_int_1", "tinyint", "Out of range for tinyint: -65537"), + new InvalidCastTestCase("c_int_2", "tinyint", "Out of range for tinyint: 65536"), + new InvalidCastTestCase("c_bigint_1", "tinyint", "Out of range for tinyint: -2147483649"), + new InvalidCastTestCase("c_bigint_2", "tinyint", "Out of range for tinyint: 2147483648"), + + new InvalidCastTestCase("c_int_1", "smallint", "Out of range for smallint: -65537", "(?s).*ERROR: Value out of range for 2 bytes(?s).*"), + new InvalidCastTestCase("c_int_2", "smallint", "Out of range for smallint: 65536", "(?s).*Value out of range for 2 bytes(?s).*"), + new InvalidCastTestCase("c_bigint_1", "smallint", "Out of range for smallint: -2147483649", "(?s).*Value out of range for 2 bytes(?s).*"), + new InvalidCastTestCase("c_bigint_2", "smallint", "Out of range for smallint: 2147483648", "(?s).*Value out of range for 2 bytes(?s).*"), + new InvalidCastTestCase("c_bigint_1", "integer", "Out of range for integer: -2147483649", "(?s).*Value out of range for 4 bytes(?s).*"), + new InvalidCastTestCase("c_bigint_2", "integer", "Out of range for integer: 2147483648", "(?s).*Value out of range for 4 bytes(?s).*"), + + // Not pushdown for tinyint type + new InvalidCastTestCase("c_decimal_1", "tinyint", "Cannot cast '-129.49' to TINYINT"), + new InvalidCastTestCase("c_decimal_2", "tinyint", "Cannot cast '-128.94' to TINYINT"), + new InvalidCastTestCase("c_decimal_3", "tinyint", "Cannot cast '127.94' to TINYINT"), + new InvalidCastTestCase("c_decimal_4", "tinyint", "Cannot cast '128.49' to TINYINT"), + + new InvalidCastTestCase("c_decimal_5", "smallint", "Cannot cast '-65537.49' to SMALLINT", "(?s).*Value out of range for 2 bytes(?s).*"), + new InvalidCastTestCase("c_decimal_6", "smallint", "Cannot cast '-65536.94' to SMALLINT", "(?s).*Value out of range for 2 bytes(?s).*"), + new InvalidCastTestCase("c_decimal_7", "smallint", "Cannot cast '65535.94' to SMALLINT", "(?s).*Value out of range for 2 bytes(?s).*"), + new InvalidCastTestCase("c_decimal_8", "smallint", "Cannot cast '65536.49' to SMALLINT", "(?s).*Value out of range for 2 bytes(?s).*"), + new InvalidCastTestCase("c_decimal_9", "integer", "Cannot cast '-2147483649.49' to INTEGER", "(?s).*Value out of range for 4 bytes(?s).*"), + new InvalidCastTestCase("c_decimal_10", "integer", "Cannot cast '-2147483648.94' to INTEGER", "(?s).*Value out of range for 4 bytes(?s).*"), + new InvalidCastTestCase("c_decimal_11", "integer", "Cannot cast '2147483647.94' to INTEGER", "(?s).*Value out of range for 4 bytes(?s).*"), + new InvalidCastTestCase("c_decimal_12", "integer", "Cannot cast '2147483648.49' to INTEGER", "(?s).*Value out of range for 4 bytes(?s).*"), + new InvalidCastTestCase("c_decimal_13", "bigint", "Cannot cast '-9223372036854775809.49' to BIGINT", "(?s).*Value out of range for 8 bytes(?s).*"), + new InvalidCastTestCase("c_decimal_14", "bigint", "Cannot cast '-9223372036854775808.94' to BIGINT", "(?s).*Value out of range for 8 bytes(?s).*"), + new InvalidCastTestCase("c_decimal_15", "bigint", "Cannot cast '9223372036854775807.94' to BIGINT", "(?s).*Value out of range for 8 bytes(?s).*"), + new InvalidCastTestCase("c_decimal_16", "bigint", "Cannot cast '9223372036854775808.49' to BIGINT", "(?s).*Value out of range for 8 bytes(?s).*"), + + // No pushdown for real datatype to integral types + new InvalidCastTestCase("c_infinity_real_1", "tinyint", "Out of range for tinyint: Infinity"), + new InvalidCastTestCase("c_infinity_real_1", "smallint", "Out of range for smallint: Infinity"), + new InvalidCastTestCase("c_infinity_real_1", "integer", "Out of range for integer: Infinity"), + new InvalidCastTestCase("c_infinity_real_2", "tinyint", "Out of range for tinyint: -Infinity"), + new InvalidCastTestCase("c_infinity_real_2", "smallint", "Out of range for smallint: -Infinity"), + new InvalidCastTestCase("c_infinity_real_2", "integer", "Out of range for integer: -Infinity"), + + // No pushdown for double precision datatype to integral types + new InvalidCastTestCase("c_infinity_double_1", "tinyint", "Out of range for tinyint: Infinity"), + new InvalidCastTestCase("c_infinity_double_1", "smallint", "Out of range for smallint: Infinity"), + new InvalidCastTestCase("c_infinity_double_1", "integer", "Out of range for integer: Infinity"), + new InvalidCastTestCase("c_infinity_double_1", "bigint", "Unable to cast Infinity to bigint"), + new InvalidCastTestCase("c_infinity_double_2", "tinyint", "Out of range for tinyint: -Infinity"), + new InvalidCastTestCase("c_infinity_double_2", "smallint", "Out of range for smallint: -Infinity"), + new InvalidCastTestCase("c_infinity_double_2", "integer", "Out of range for integer: -Infinity"), + new InvalidCastTestCase("c_infinity_double_2", "bigint", "Unable to cast -Infinity to bigint"))); + } + + @Test + public void testCastRealInfinityValueToBigint() + { + assertThat(query("SELECT CAST(c_Infinity_real AS BIGINT) FROM %s".formatted(leftTable()))) + .matches("VALUES (BIGINT '9223372036854775807'), (BIGINT '-9223372036854775808'), (null)") + .isNotFullyPushedDown(ProjectNode.class); + } + + @Test + public void testCastPushdownWithForcedTypedToInteger() + { + // These column types are not supported by default by trino. These types are forced mapped to varchar. + assertThat(query("SELECT CAST(c_super AS INTEGER) FROM %s".formatted(leftTable()))) + .isNotFullyPushedDown(ProjectNode.class); + assertThat(query("SELECT CAST(c_super AS BIGINT) FROM %s".formatted(leftTable()))) + .isNotFullyPushedDown(ProjectNode.class); + } + + @Override + protected List supportedCastTypePushdown() + { + return ImmutableList.of( + new CastTestCase("c_boolean", "smallint", "c_smallint"), + new CastTestCase("c_smallint", "smallint", "c_smallint"), + new CastTestCase("c_integer", "smallint", "c_smallint"), + new CastTestCase("c_bigint", "smallint", "c_smallint"), + new CastTestCase("c_decimal_10_2", "smallint", "c_smallint"), + new CastTestCase("c_decimal_19_2", "smallint", "c_smallint"), + new CastTestCase("c_decimal_30_2", "smallint", "c_smallint"), + new CastTestCase("c_decimal_negative", "smallint", "c_smallint"), + + new CastTestCase("c_boolean", "integer", "c_integer"), + new CastTestCase("c_smallint", "integer", "c_integer"), + new CastTestCase("c_int2", "integer", "c_integer"), + new CastTestCase("c_integer", "integer", "c_integer"), + new CastTestCase("c_int", "integer", "c_integer"), + new CastTestCase("c_int4", "integer", "c_integer"), + new CastTestCase("c_bigint", "integer", "c_integer"), + new CastTestCase("c_int8", "integer", "c_integer"), + new CastTestCase("c_decimal_10_2", "integer", "c_integer"), + new CastTestCase("c_decimal_19_2", "integer", "c_integer"), + new CastTestCase("c_decimal_30_2", "integer", "c_integer"), + new CastTestCase("c_numeric_10_2", "integer", "c_integer"), + new CastTestCase("c_numeric_19_2", "integer", "c_integer"), + new CastTestCase("c_numeric_30_2", "integer", "c_integer"), + new CastTestCase("c_decimal_negative", "integer", "c_integer"), + + new CastTestCase("c_boolean", "bigint", "c_bigint"), + new CastTestCase("c_smallint", "bigint", "c_bigint"), + new CastTestCase("c_integer", "bigint", "c_bigint"), + new CastTestCase("c_bigint", "bigint", "c_bigint"), + new CastTestCase("c_decimal_10_2", "bigint", "c_bigint"), + new CastTestCase("c_decimal_19_2", "bigint", "c_bigint"), + new CastTestCase("c_decimal_30_2", "bigint", "c_bigint"), + new CastTestCase("c_decimal_negative", "bigint", "c_bigint")); + } + + @Override + protected List unsupportedCastTypePushdown() + { + return ImmutableList.of( + new CastTestCase("c_boolean", "tinyint", "c_smallint"), + new CastTestCase("c_smallint", "tinyint", "c_smallint"), + new CastTestCase("c_integer", "tinyint", "c_smallint"), + new CastTestCase("c_bigint", "tinyint", "c_smallint"), + new CastTestCase("c_decimal_10_2", "tinyint", "c_smallint"), + new CastTestCase("c_decimal_19_2", "tinyint", "c_smallint"), + new CastTestCase("c_decimal_30_2", "tinyint", "c_smallint"), + new CastTestCase("c_decimal_negative", "tinyint", "c_smallint"), + + new CastTestCase("c_real", "tinyint", "c_smallint"), + new CastTestCase("c_float4", "tinyint", "c_smallint"), + new CastTestCase("c_double_precision", "tinyint", "c_smallint"), + new CastTestCase("c_float", "tinyint", "c_smallint"), + new CastTestCase("c_float8", "tinyint", "c_smallint"), + new CastTestCase("c_varchar_numeric", "tinyint", "c_smallint"), + new CastTestCase("c_text_numeric", "tinyint", "c_smallint"), + new CastTestCase("c_nvarchar_numeric", "tinyint", "c_smallint"), + new CastTestCase("c_varchar_numeric_sign", "tinyint", "c_smallint"), + + new CastTestCase("c_real", "smallint", "c_smallint"), + new CastTestCase("c_float4", "smallint", "c_smallint"), + new CastTestCase("c_double_precision", "smallint", "c_smallint"), + new CastTestCase("c_float", "smallint", "c_smallint"), + new CastTestCase("c_float8", "smallint", "c_smallint"), + new CastTestCase("c_varchar_numeric", "smallint", "c_smallint"), + new CastTestCase("c_text_numeric", "smallint", "c_smallint"), + new CastTestCase("c_nvarchar_numeric", "smallint", "c_smallint"), + new CastTestCase("c_varchar_numeric_sign", "smallint", "c_smallint"), + + new CastTestCase("c_real", "integer", "c_integer"), + new CastTestCase("c_float4", "integer", "c_integer"), + new CastTestCase("c_double_precision", "integer", "c_integer"), + new CastTestCase("c_float", "integer", "c_integer"), + new CastTestCase("c_float8", "integer", "c_integer"), + new CastTestCase("c_varchar_numeric", "integer", "c_integer"), + new CastTestCase("c_text_numeric", "integer", "c_integer"), + new CastTestCase("c_nvarchar_numeric", "integer", "c_integer"), + new CastTestCase("c_varchar_numeric_sign", "integer", "c_integer"), + + new CastTestCase("c_real", "bigint", "c_bigint"), + new CastTestCase("c_float4", "bigint", "c_bigint"), + new CastTestCase("c_double_precision", "bigint", "c_bigint"), + new CastTestCase("c_float", "bigint", "c_bigint"), + new CastTestCase("c_float8", "bigint", "c_bigint"), + new CastTestCase("c_varchar_numeric", "bigint", "c_bigint"), + new CastTestCase("c_text_numeric", "bigint", "c_bigint"), + new CastTestCase("c_nvarchar_numeric", "bigint", "c_bigint"), + new CastTestCase("c_varchar_numeric_sign", "bigint", "c_bigint"), + + new CastTestCase("c_smallint", "boolean", "c_boolean"), + new CastTestCase("c_real", "double", "c_double_precision"), + new CastTestCase("c_double_precision", "real", "c_real"), + new CastTestCase("c_double_precision", "decimal(10,2)", "c_decimal_10_2"), + new CastTestCase("c_char_10", "char(50)", "c_char_50"), + new CastTestCase("c_char_10", "char(256)", "c_bpchar"), + new CastTestCase("c_varchar_10", "varchar(50)", "c_varchar_50"), + new CastTestCase("c_nvarchar_10", "varchar(50)", "c_nvarchar_50"), + new CastTestCase("c_varchar_10", "varchar(50)", "c_text"), + new CastTestCase("c_timestamp", "date", "c_date"), + new CastTestCase("c_timestamp", "time", "c_time"), + new CastTestCase("c_date", "timestamp", "c_timestamp"), + new CastTestCase("c_varchar_timestamp", "timestamp", "c_timestamp"), + new CastTestCase("c_varchar_timestamptz", "timestamp", "c_timestamp")); + } + + @Override + protected List invalidCast() + { + return ImmutableList.of( + new InvalidCastTestCase("c_varchar_decimal", "integer"), + new InvalidCastTestCase("c_varchar_decimal_sign", "integer"), + new InvalidCastTestCase("c_varchar_alpha_numeric", "integer"), + new InvalidCastTestCase("c_char_50", "integer"), + new InvalidCastTestCase("c_char_numeric", "integer"), + new InvalidCastTestCase("c_bpchar_numeric", "integer"), + new InvalidCastTestCase("c_nan_real", "integer"), + new InvalidCastTestCase("c_nan_double", "integer"), + + // c_timetz is not supported by default by trino. This is forced mapped to varchar. + new InvalidCastTestCase("c_timetz", "tinyint"), + new InvalidCastTestCase("c_timetz", "smallint"), + new InvalidCastTestCase("c_timetz", "int"), + new InvalidCastTestCase("c_timetz", "bigint")); + } +} diff --git a/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/TestRedshiftConnectorTest.java b/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/TestRedshiftConnectorTest.java index 996e2547f083..2b47cf8b7780 100644 --- a/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/TestRedshiftConnectorTest.java +++ b/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/TestRedshiftConnectorTest.java @@ -692,6 +692,90 @@ public void testInsertRowConcurrently() abort("Test fails with a timeout sometimes and is flaky"); } + @Test + public void testJoinPushdownWithImplicitCast() + { + try (TestTable leftTable = new TestTable( + getQueryRunner()::execute, + "left_table_", + "(id int, c_boolean boolean, c_tinyint tinyint, c_smallint smallint, c_integer integer, c_bigint bigint, c_real real, c_double_precision double precision, c_decimal_10_2 decimal(10, 2))", + ImmutableList.of( + "(11, true, 12, 12, 12, 12, 12.34, 12.34, 12.34)", + "(12, false, 123, 123, 123, 123, 123.67, 123.67, 123.67)")); + TestTable rightTable = new TestTable( + getQueryRunner()::execute, + "right_table_", + "(id int, c_boolean boolean, c_tinyint tinyint, c_smallint smallint, c_integer integer, c_bigint bigint, c_real real, c_double_precision double precision, c_decimal_10_2 decimal(10, 2))", + ImmutableList.of( + "(21, true, 12, 12, 12, 12, 12.34, 12.34, 12.34)", + "(22, true, 234, 234, 234, 234, 234.67, 234.67, 234.67)"))) { + Session session = joinPushdownEnabled(getSession()); + String joinQuery = "SELECT l.id FROM " + leftTable.getName() + " l %s " + rightTable.getName() + " r ON %s"; + + assertThat(query(session,joinQuery.formatted("LEFT JOIN", "l.c_tinyint = r.c_bigint"))) + .isFullyPushedDown(); + assertThat(query(session,joinQuery.formatted("RIGHT JOIN", "l.c_tinyint = r.c_bigint"))) + .isFullyPushedDown(); + assertThat(query(session,joinQuery.formatted("INNER JOIN", "l.c_tinyint = r.c_bigint"))) + .isFullyPushedDown(); + // Full Join pushdown is not supported + assertThat(query(session,joinQuery.formatted("FULL JOIN", "l.c_tinyint = r.c_bigint"))) + .joinIsNotFullyPushedDown(); + + assertThat(query(session,joinQuery.formatted("LEFT JOIN", "l.c_smallint = r.c_bigint"))) + .isFullyPushedDown(); + assertThat(query(session,joinQuery.formatted("RIGHT JOIN", "l.c_smallint = r.c_bigint"))) + .isFullyPushedDown(); + assertThat(query(session,joinQuery.formatted("INNER JOIN", "l.c_smallint = r.c_bigint"))) + .isFullyPushedDown(); + // Full Join pushdown is not supported + assertThat(query(session,joinQuery.formatted("FULL JOIN", "l.c_smallint = r.c_bigint"))) + .joinIsNotFullyPushedDown(); + + assertThat(query(session,joinQuery.formatted("LEFT JOIN", "l.c_integer = r.c_bigint"))) + .isFullyPushedDown(); + assertThat(query(session,joinQuery.formatted("RIGHT JOIN", "l.c_integer = r.c_bigint"))) + .isFullyPushedDown(); + assertThat(query(session,joinQuery.formatted("INNER JOIN", "l.c_integer = r.c_bigint"))) + .isFullyPushedDown(); + // Full Join pushdown is not supported + assertThat(query(session,joinQuery.formatted("FULL JOIN", "l.c_integer = r.c_bigint"))) + .joinIsNotFullyPushedDown(); + + // Below cases try to implicit cast from bigint type to real/double/decimal type. + // CAST pushdown with real/double/decimal type is not supported yet. + assertThat(query(session,joinQuery.formatted("LEFT JOIN", "l.c_real = r.c_bigint"))) + .joinIsNotFullyPushedDown(); + assertThat(query(session,joinQuery.formatted("RIGHT JOIN", "l.c_real = r.c_bigint"))) + .joinIsNotFullyPushedDown(); + assertThat(query(session,joinQuery.formatted("INNER JOIN", "l.c_real = r.c_bigint"))) + .joinIsNotFullyPushedDown(); + // Full Join pushdown is not supported + assertThat(query(session,joinQuery.formatted("FULL JOIN", "l.c_real = r.c_bigint"))) + .joinIsNotFullyPushedDown(); + + assertThat(query(session,joinQuery.formatted("LEFT JOIN", "l.c_double_precision = r.c_bigint"))) + .joinIsNotFullyPushedDown(); + assertThat(query(session,joinQuery.formatted("RIGHT JOIN", "l.c_double_precision = r.c_bigint"))) + .joinIsNotFullyPushedDown(); + assertThat(query(session,joinQuery.formatted("INNER JOIN", "l.c_double_precision = r.c_bigint"))) + .joinIsNotFullyPushedDown(); + // Full Join pushdown is not supported + assertThat(query(session,joinQuery.formatted("FULL JOIN", "l.c_double_precision = r.c_bigint"))) + .joinIsNotFullyPushedDown(); + + assertThat(query(session,joinQuery.formatted("LEFT JOIN", "l.c_decimal_10_2 = r.c_bigint"))) + .joinIsNotFullyPushedDown(); + assertThat(query(session,joinQuery.formatted("RIGHT JOIN", "l.c_decimal_10_2 = r.c_bigint"))) + .joinIsNotFullyPushedDown(); + assertThat(query(session,joinQuery.formatted("INNER JOIN", "l.c_decimal_10_2 = r.c_bigint"))) + .joinIsNotFullyPushedDown(); + // Full Join pushdown is not supported + assertThat(query(session,joinQuery.formatted("FULL JOIN", "l.c_decimal_10_2 = r.c_bigint"))) + .joinIsNotFullyPushedDown(); + } + } + @Override protected Session joinPushdownEnabled(Session session) {