diff --git a/plugin/trino-snowflake/src/main/java/io/trino/plugin/snowflake/ImplementAvgBigint.java b/plugin/trino-snowflake/src/main/java/io/trino/plugin/snowflake/ImplementAvgBigint.java new file mode 100644 index 000000000000..fcdc14d3261f --- /dev/null +++ b/plugin/trino-snowflake/src/main/java/io/trino/plugin/snowflake/ImplementAvgBigint.java @@ -0,0 +1,26 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.snowflake; + +import io.trino.plugin.jdbc.aggregation.BaseImplementAvgBigint; + +public class ImplementAvgBigint + extends BaseImplementAvgBigint +{ + @Override + protected String getRewriteFormatExpression() + { + return "avg((%s * 1.0))"; + } +} diff --git a/plugin/trino-snowflake/src/main/java/io/trino/plugin/snowflake/SnowflakeClient.java b/plugin/trino-snowflake/src/main/java/io/trino/plugin/snowflake/SnowflakeClient.java index e7ece220bad8..fb3661c45566 100644 --- a/plugin/trino-snowflake/src/main/java/io/trino/plugin/snowflake/SnowflakeClient.java +++ b/plugin/trino-snowflake/src/main/java/io/trino/plugin/snowflake/SnowflakeClient.java @@ -22,6 +22,7 @@ import io.trino.plugin.base.mapping.IdentifierMapping; import io.trino.plugin.jdbc.BaseJdbcClient; import io.trino.plugin.jdbc.BaseJdbcConfig; +import io.trino.plugin.jdbc.CaseSensitivity; import io.trino.plugin.jdbc.ColumnMapping; import io.trino.plugin.jdbc.ConnectionFactory; import io.trino.plugin.jdbc.JdbcColumnHandle; @@ -82,6 +83,7 @@ import java.util.Calendar; import java.util.Date; import java.util.GregorianCalendar; +import java.util.List; import java.util.Locale; import java.util.Map; import java.util.Optional; @@ -89,6 +91,8 @@ import java.util.function.BiFunction; import static com.google.common.base.Preconditions.checkArgument; +import static io.trino.plugin.jdbc.CaseSensitivity.CASE_INSENSITIVE; +import static io.trino.plugin.jdbc.CaseSensitivity.CASE_SENSITIVE; import static io.trino.plugin.jdbc.JdbcErrorCode.JDBC_ERROR; import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED; import static io.trino.spi.type.Timestamps.MILLISECONDS_PER_SECOND; @@ -147,6 +151,7 @@ public SnowflakeClient( .add(new ImplementSum(SnowflakeClient::toTypeHandle)) .add(new ImplementAvgFloatingPoint()) .add(new ImplementAvgDecimal()) + .add(new ImplementAvgBigint()) .build()); } @@ -183,7 +188,7 @@ public Optional toColumnMapping(ConnectorSession session, Connect .put("timestampntz", handle -> Optional.of(timestampColumnMapping(handle.getRequiredDecimalDigits()))) .put("timestamptz", handle -> Optional.of(timestampTZColumnMapping(handle.getRequiredDecimalDigits()))) .put("date", handle -> Optional.of(ColumnMapping.longMapping(DateType.DATE, (resultSet, columnIndex) -> LocalDate.ofEpochDay(resultSet.getLong(columnIndex)).toEpochDay(), snowFlakeDateWriter()))) - .put("varchar", handle -> Optional.of(varcharColumnMapping(handle.getRequiredColumnSize()))) + .put("varchar", handle -> Optional.of(varcharColumnMapping(handle.getRequiredColumnSize(), typeHandle.getCaseSensitivity()))) .put("number", handle -> { int decimalDigits = handle.getRequiredDecimalDigits(); int precision = handle.getRequiredColumnSize() + Math.max(-decimalDigits, 0); @@ -254,6 +259,13 @@ public Optional implementAggregation(ConnectorSession session, A return aggregateFunctionRewriter.rewrite(session, aggregate, assignments); } + @Override + public boolean supportsAggregationPushdown(ConnectorSession session, JdbcTableHandle table, List aggregates, Map assignments, List> groupingSets) + { + // Remote database can be case insensitive. + return preventTextualTypeAggregationPushdown(groupingSets); + } + private static Optional toTypeHandle(DecimalType decimalType) { return Optional.of(new JdbcTypeHandle(Types.NUMERIC, Optional.of("decimal"), Optional.of(decimalType.getPrecision()), Optional.of(decimalType.getScale()), Optional.empty(), Optional.empty())); @@ -413,13 +425,10 @@ public void set(PreparedStatement statement, int index, long picosOfDay) }; } - private static ColumnMapping varcharColumnMapping(int varcharLength) + private static ColumnMapping varcharColumnMapping(int varcharLength, Optional caseSensitivity) { VarcharType varcharType = varcharLength <= VarcharType.MAX_LENGTH ? createVarcharType(varcharLength) : createUnboundedVarcharType(); - return ColumnMapping.sliceMapping( - varcharType, - StandardColumnMappings.varcharReadFunction(varcharType), - StandardColumnMappings.varcharWriteFunction()); + return StandardColumnMappings.varcharColumnMapping(varcharType, caseSensitivity.orElse(CASE_INSENSITIVE) == CASE_SENSITIVE); } private static ObjectWriteFunction longTimestampWithTzWriteFunction() diff --git a/plugin/trino-snowflake/src/test/java/io/trino/plugin/snowflake/TestSnowflakeConnectorTest.java b/plugin/trino-snowflake/src/test/java/io/trino/plugin/snowflake/TestSnowflakeConnectorTest.java index c5d7d46764f3..bb43e6d80d8d 100644 --- a/plugin/trino-snowflake/src/test/java/io/trino/plugin/snowflake/TestSnowflakeConnectorTest.java +++ b/plugin/trino-snowflake/src/test/java/io/trino/plugin/snowflake/TestSnowflakeConnectorTest.java @@ -66,14 +66,20 @@ protected SqlExecutor onRemoteDatabase() protected boolean hasBehavior(TestingConnectorBehavior connectorBehavior) { return switch (connectorBehavior) { + case SUPPORTS_AGGREGATION_PUSHDOWN -> true; case SUPPORTS_ADD_COLUMN_WITH_COMMENT, - SUPPORTS_AGGREGATION_PUSHDOWN, + SUPPORTS_AGGREGATION_PUSHDOWN_CORRELATION, + SUPPORTS_AGGREGATION_PUSHDOWN_COUNT_DISTINCT, + SUPPORTS_AGGREGATION_PUSHDOWN_COVARIANCE, + SUPPORTS_AGGREGATION_PUSHDOWN_REGRESSION, + SUPPORTS_AGGREGATION_PUSHDOWN_STDDEV, + SUPPORTS_AGGREGATION_PUSHDOWN_VARIANCE, SUPPORTS_ARRAY, SUPPORTS_COMMENT_ON_COLUMN, SUPPORTS_COMMENT_ON_TABLE, SUPPORTS_CREATE_TABLE_WITH_COLUMN_COMMENT, SUPPORTS_CREATE_TABLE_WITH_TABLE_COMMENT, - SUPPORTS_LIMIT_PUSHDOWN, + SUPPORTS_PREDICATE_PUSHDOWN, SUPPORTS_ROW_TYPE, SUPPORTS_SET_COLUMN_TYPE, SUPPORTS_TOPN_PUSHDOWN -> false; @@ -210,13 +216,6 @@ public void testCountDistinctWithStringTypes() abort("TODO"); } - @Test - @Override - public void testAggregationPushdown() - { - abort("TODO"); - } - @Test @Override public void testDistinctAggregationPushdown()