diff --git a/plugin/trino-clickhouse/pom.xml b/plugin/trino-clickhouse/pom.xml index d29125cb94a8..d24f8a46fa32 100644 --- a/plugin/trino-clickhouse/pom.xml +++ b/plugin/trino-clickhouse/pom.xml @@ -23,11 +23,6 @@ trino-base-jdbc - - io.trino - trino-matching - - io.trino trino-plugin-toolkit diff --git a/plugin/trino-clickhouse/src/main/java/io/trino/plugin/clickhouse/ClickHouseClient.java b/plugin/trino-clickhouse/src/main/java/io/trino/plugin/clickhouse/ClickHouseClient.java index ccee07d51af4..7ab40471d762 100644 --- a/plugin/trino-clickhouse/src/main/java/io/trino/plugin/clickhouse/ClickHouseClient.java +++ b/plugin/trino-clickhouse/src/main/java/io/trino/plugin/clickhouse/ClickHouseClient.java @@ -162,7 +162,6 @@ public ClickHouseClient( .add(new ImplementMinMax(false)) // TODO: Revisit once https://github.com/trinodb/trino/issues/7100 is resolved .add(new ImplementSum(ClickHouseClient::toTypeHandle)) .add(new ImplementAvgFloatingPoint()) - .add(new ImplementAvgDecimal()) .add(new ImplementAvgBigint()) .build()); } diff --git a/plugin/trino-clickhouse/src/main/java/io/trino/plugin/clickhouse/ImplementAvgDecimal.java b/plugin/trino-clickhouse/src/main/java/io/trino/plugin/clickhouse/ImplementAvgDecimal.java deleted file mode 100644 index 4f16368a47da..000000000000 --- a/plugin/trino-clickhouse/src/main/java/io/trino/plugin/clickhouse/ImplementAvgDecimal.java +++ /dev/null @@ -1,78 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.trino.plugin.clickhouse; - -import io.trino.matching.Capture; -import io.trino.matching.Captures; -import io.trino.matching.Pattern; -import io.trino.plugin.base.expression.AggregateFunctionRule; -import io.trino.plugin.jdbc.JdbcColumnHandle; -import io.trino.plugin.jdbc.JdbcExpression; -import io.trino.spi.connector.AggregateFunction; -import io.trino.spi.expression.Variable; -import io.trino.spi.type.DecimalType; - -import java.util.Optional; - -import static com.google.common.base.Verify.verify; -import static io.trino.matching.Capture.newCapture; -import static io.trino.plugin.base.expression.AggregateFunctionPatterns.basicAggregation; -import static io.trino.plugin.base.expression.AggregateFunctionPatterns.expressionType; -import static io.trino.plugin.base.expression.AggregateFunctionPatterns.functionName; -import static io.trino.plugin.base.expression.AggregateFunctionPatterns.singleInput; -import static io.trino.plugin.base.expression.AggregateFunctionPatterns.variable; -import static io.trino.plugin.clickhouse.ClickHouseClient.CLICKHOUSE_MAX_DECIMAL_PRECISION; -import static java.lang.String.format; - -/** - * Implements {@code avg(decimal(p, s)} - */ -public class ImplementAvgDecimal - implements AggregateFunctionRule -{ - private static final Capture INPUT = newCapture(); - - @Override - public Pattern getPattern() - { - return basicAggregation() - .with(functionName().equalTo("avg")) - .with(singleInput().matching( - variable() - .with(expressionType().matching(DecimalType.class::isInstance)) - .capturedAs(INPUT))); - } - - @Override - public Optional rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context) - { - Variable input = captures.get(INPUT); - JdbcColumnHandle columnHandle = (JdbcColumnHandle) context.getAssignment(input.getName()); - DecimalType type = (DecimalType) columnHandle.getColumnType(); - verify(aggregateFunction.getOutputType().equals(type)); - - // When decimal type has maximum precision we can get result that does not match Trino avg semantics. - if (type.getPrecision() == CLICKHOUSE_MAX_DECIMAL_PRECISION) { - return Optional.of(new JdbcExpression( - format("avg(CAST(%s AS decimal(%s, %s)))", context.getIdentifierQuote().apply(columnHandle.getColumnName()), type.getPrecision(), type.getScale()), - columnHandle.getJdbcTypeHandle())); - } - - // ClickHouse avg function rounds down resulting decimal. - // To match Trino avg semantics, we extend scale by 1 and round result to target scale. - return Optional.of(new JdbcExpression( - format("round(avg(CAST(%s AS decimal(%s, %s))), %s)", context.getIdentifierQuote().apply(columnHandle.getColumnName()), type.getPrecision() + 1, type.getScale() + 1, type.getScale()), - columnHandle.getJdbcTypeHandle())); - } -} diff --git a/plugin/trino-clickhouse/src/test/java/io/trino/plugin/clickhouse/TestClickHouseConnectorTest.java b/plugin/trino-clickhouse/src/test/java/io/trino/plugin/clickhouse/TestClickHouseConnectorTest.java index 51a28ebd1f32..368c0412d6e3 100644 --- a/plugin/trino-clickhouse/src/test/java/io/trino/plugin/clickhouse/TestClickHouseConnectorTest.java +++ b/plugin/trino-clickhouse/src/test/java/io/trino/plugin/clickhouse/TestClickHouseConnectorTest.java @@ -16,6 +16,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import io.trino.plugin.jdbc.BaseJdbcConnectorTest; +import io.trino.sql.planner.plan.AggregationNode; import io.trino.testing.MaterializedResult; import io.trino.testing.QueryRunner; import io.trino.testing.TestingConnectorBehavior; @@ -432,7 +433,8 @@ public void testNumericAggregationPushdown() assertThat(query("SELECT min(short_decimal), min(long_decimal), min(a_bigint), min(t_double) FROM " + testTable.getName())).isFullyPushedDown(); assertThat(query("SELECT max(short_decimal), max(long_decimal), max(a_bigint), max(t_double) FROM " + testTable.getName())).isFullyPushedDown(); assertThat(query("SELECT sum(short_decimal), sum(long_decimal), sum(a_bigint), sum(t_double) FROM " + testTable.getName())).isFullyPushedDown(); - assertThat(query("SELECT avg(short_decimal), avg(long_decimal), avg(a_bigint), avg(t_double) FROM " + testTable.getName())).isFullyPushedDown(); + assertThat(query("SELECT avg(a_bigint), avg(t_double) FROM " + testTable.getName())).isFullyPushedDown(); + assertThat(query("SELECT avg(short_decimal), avg(long_decimal) FROM " + testTable.getName())).isNotFullyPushedDown(AggregationNode.class); } }