diff --git a/plugin/trino-clickhouse/pom.xml b/plugin/trino-clickhouse/pom.xml
index d29125cb94a8..d24f8a46fa32 100644
--- a/plugin/trino-clickhouse/pom.xml
+++ b/plugin/trino-clickhouse/pom.xml
@@ -23,11 +23,6 @@
trino-base-jdbc
-
- io.trino
- trino-matching
-
-
io.trino
trino-plugin-toolkit
diff --git a/plugin/trino-clickhouse/src/main/java/io/trino/plugin/clickhouse/ClickHouseClient.java b/plugin/trino-clickhouse/src/main/java/io/trino/plugin/clickhouse/ClickHouseClient.java
index ccee07d51af4..7ab40471d762 100644
--- a/plugin/trino-clickhouse/src/main/java/io/trino/plugin/clickhouse/ClickHouseClient.java
+++ b/plugin/trino-clickhouse/src/main/java/io/trino/plugin/clickhouse/ClickHouseClient.java
@@ -162,7 +162,6 @@ public ClickHouseClient(
.add(new ImplementMinMax(false)) // TODO: Revisit once https://github.com/trinodb/trino/issues/7100 is resolved
.add(new ImplementSum(ClickHouseClient::toTypeHandle))
.add(new ImplementAvgFloatingPoint())
- .add(new ImplementAvgDecimal())
.add(new ImplementAvgBigint())
.build());
}
diff --git a/plugin/trino-clickhouse/src/main/java/io/trino/plugin/clickhouse/ImplementAvgDecimal.java b/plugin/trino-clickhouse/src/main/java/io/trino/plugin/clickhouse/ImplementAvgDecimal.java
deleted file mode 100644
index 4f16368a47da..000000000000
--- a/plugin/trino-clickhouse/src/main/java/io/trino/plugin/clickhouse/ImplementAvgDecimal.java
+++ /dev/null
@@ -1,78 +0,0 @@
-/*
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package io.trino.plugin.clickhouse;
-
-import io.trino.matching.Capture;
-import io.trino.matching.Captures;
-import io.trino.matching.Pattern;
-import io.trino.plugin.base.expression.AggregateFunctionRule;
-import io.trino.plugin.jdbc.JdbcColumnHandle;
-import io.trino.plugin.jdbc.JdbcExpression;
-import io.trino.spi.connector.AggregateFunction;
-import io.trino.spi.expression.Variable;
-import io.trino.spi.type.DecimalType;
-
-import java.util.Optional;
-
-import static com.google.common.base.Verify.verify;
-import static io.trino.matching.Capture.newCapture;
-import static io.trino.plugin.base.expression.AggregateFunctionPatterns.basicAggregation;
-import static io.trino.plugin.base.expression.AggregateFunctionPatterns.expressionType;
-import static io.trino.plugin.base.expression.AggregateFunctionPatterns.functionName;
-import static io.trino.plugin.base.expression.AggregateFunctionPatterns.singleInput;
-import static io.trino.plugin.base.expression.AggregateFunctionPatterns.variable;
-import static io.trino.plugin.clickhouse.ClickHouseClient.CLICKHOUSE_MAX_DECIMAL_PRECISION;
-import static java.lang.String.format;
-
-/**
- * Implements {@code avg(decimal(p, s)}
- */
-public class ImplementAvgDecimal
- implements AggregateFunctionRule
-{
- private static final Capture INPUT = newCapture();
-
- @Override
- public Pattern getPattern()
- {
- return basicAggregation()
- .with(functionName().equalTo("avg"))
- .with(singleInput().matching(
- variable()
- .with(expressionType().matching(DecimalType.class::isInstance))
- .capturedAs(INPUT)));
- }
-
- @Override
- public Optional rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context)
- {
- Variable input = captures.get(INPUT);
- JdbcColumnHandle columnHandle = (JdbcColumnHandle) context.getAssignment(input.getName());
- DecimalType type = (DecimalType) columnHandle.getColumnType();
- verify(aggregateFunction.getOutputType().equals(type));
-
- // When decimal type has maximum precision we can get result that does not match Trino avg semantics.
- if (type.getPrecision() == CLICKHOUSE_MAX_DECIMAL_PRECISION) {
- return Optional.of(new JdbcExpression(
- format("avg(CAST(%s AS decimal(%s, %s)))", context.getIdentifierQuote().apply(columnHandle.getColumnName()), type.getPrecision(), type.getScale()),
- columnHandle.getJdbcTypeHandle()));
- }
-
- // ClickHouse avg function rounds down resulting decimal.
- // To match Trino avg semantics, we extend scale by 1 and round result to target scale.
- return Optional.of(new JdbcExpression(
- format("round(avg(CAST(%s AS decimal(%s, %s))), %s)", context.getIdentifierQuote().apply(columnHandle.getColumnName()), type.getPrecision() + 1, type.getScale() + 1, type.getScale()),
- columnHandle.getJdbcTypeHandle()));
- }
-}
diff --git a/plugin/trino-clickhouse/src/test/java/io/trino/plugin/clickhouse/TestClickHouseConnectorTest.java b/plugin/trino-clickhouse/src/test/java/io/trino/plugin/clickhouse/TestClickHouseConnectorTest.java
index 51a28ebd1f32..368c0412d6e3 100644
--- a/plugin/trino-clickhouse/src/test/java/io/trino/plugin/clickhouse/TestClickHouseConnectorTest.java
+++ b/plugin/trino-clickhouse/src/test/java/io/trino/plugin/clickhouse/TestClickHouseConnectorTest.java
@@ -16,6 +16,7 @@
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.trino.plugin.jdbc.BaseJdbcConnectorTest;
+import io.trino.sql.planner.plan.AggregationNode;
import io.trino.testing.MaterializedResult;
import io.trino.testing.QueryRunner;
import io.trino.testing.TestingConnectorBehavior;
@@ -432,7 +433,8 @@ public void testNumericAggregationPushdown()
assertThat(query("SELECT min(short_decimal), min(long_decimal), min(a_bigint), min(t_double) FROM " + testTable.getName())).isFullyPushedDown();
assertThat(query("SELECT max(short_decimal), max(long_decimal), max(a_bigint), max(t_double) FROM " + testTable.getName())).isFullyPushedDown();
assertThat(query("SELECT sum(short_decimal), sum(long_decimal), sum(a_bigint), sum(t_double) FROM " + testTable.getName())).isFullyPushedDown();
- assertThat(query("SELECT avg(short_decimal), avg(long_decimal), avg(a_bigint), avg(t_double) FROM " + testTable.getName())).isFullyPushedDown();
+ assertThat(query("SELECT avg(a_bigint), avg(t_double) FROM " + testTable.getName())).isFullyPushedDown();
+ assertThat(query("SELECT avg(short_decimal), avg(long_decimal) FROM " + testTable.getName())).isNotFullyPushedDown(AggregationNode.class);
}
}