diff --git a/presto-mysql/pom.xml b/presto-mysql/pom.xml
index 1bb3ca454711..7786cf0e2b35 100644
--- a/presto-mysql/pom.xml
+++ b/presto-mysql/pom.xml
@@ -22,6 +22,11 @@
presto-base-jdbc
+
+ io.prestosql
+ presto-matching
+
+
io.airlift
configuration
diff --git a/presto-mysql/src/main/java/io/prestosql/plugin/mysql/ImplementAvgBigint.java b/presto-mysql/src/main/java/io/prestosql/plugin/mysql/ImplementAvgBigint.java
new file mode 100644
index 000000000000..29ced9fe0b49
--- /dev/null
+++ b/presto-mysql/src/main/java/io/prestosql/plugin/mysql/ImplementAvgBigint.java
@@ -0,0 +1,69 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.prestosql.plugin.mysql;
+
+import io.prestosql.matching.Capture;
+import io.prestosql.matching.Captures;
+import io.prestosql.matching.Pattern;
+import io.prestosql.plugin.jdbc.JdbcColumnHandle;
+import io.prestosql.plugin.jdbc.JdbcExpression;
+import io.prestosql.plugin.jdbc.JdbcTypeHandle;
+import io.prestosql.plugin.jdbc.expression.AggregateFunctionRule;
+import io.prestosql.spi.connector.AggregateFunction;
+import io.prestosql.spi.expression.Variable;
+
+import java.sql.Types;
+import java.util.Optional;
+
+import static com.google.common.base.Verify.verify;
+import static com.google.common.base.Verify.verifyNotNull;
+import static io.prestosql.matching.Capture.newCapture;
+import static io.prestosql.plugin.jdbc.expression.AggregateFunctionPatterns.basicAggregation;
+import static io.prestosql.plugin.jdbc.expression.AggregateFunctionPatterns.expressionType;
+import static io.prestosql.plugin.jdbc.expression.AggregateFunctionPatterns.functionName;
+import static io.prestosql.plugin.jdbc.expression.AggregateFunctionPatterns.singleInput;
+import static io.prestosql.plugin.jdbc.expression.AggregateFunctionPatterns.variable;
+import static io.prestosql.spi.type.BigintType.BIGINT;
+import static io.prestosql.spi.type.DoubleType.DOUBLE;
+import static java.lang.String.format;
+
+public class ImplementAvgBigint
+ implements AggregateFunctionRule
+{
+ private static final Capture INPUT = newCapture();
+
+ @Override
+ public Pattern getPattern()
+ {
+ return basicAggregation()
+ .with(functionName().equalTo("avg"))
+ .with(singleInput().matching(
+ variable()
+ .with(expressionType().matching(type -> type == BIGINT))
+ .capturedAs(INPUT)));
+ }
+
+ @Override
+ public Optional rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context)
+ {
+ Variable input = captures.get(INPUT);
+ JdbcColumnHandle columnHandle = (JdbcColumnHandle) context.getAssignments().get(input.getName());
+ verifyNotNull(columnHandle, "Unbound variable: %s", input);
+ verify(aggregateFunction.getOutputType() == DOUBLE);
+
+ return Optional.of(new JdbcExpression(
+ format("avg((%s * 1.0))", columnHandle.toSqlExpression(context.getIdentifierQuote())),
+ new JdbcTypeHandle(Types.DOUBLE, Optional.of("double"), 0, 0, Optional.empty(), Optional.empty())));
+ }
+}
diff --git a/presto-mysql/src/main/java/io/prestosql/plugin/mysql/MySqlClient.java b/presto-mysql/src/main/java/io/prestosql/plugin/mysql/MySqlClient.java
index d7e5fda03726..ad7c15a011a9 100644
--- a/presto-mysql/src/main/java/io/prestosql/plugin/mysql/MySqlClient.java
+++ b/presto-mysql/src/main/java/io/prestosql/plugin/mysql/MySqlClient.java
@@ -27,15 +27,27 @@
import io.prestosql.plugin.jdbc.ColumnMapping;
import io.prestosql.plugin.jdbc.ConnectionFactory;
import io.prestosql.plugin.jdbc.JdbcColumnHandle;
+import io.prestosql.plugin.jdbc.JdbcExpression;
import io.prestosql.plugin.jdbc.JdbcIdentity;
import io.prestosql.plugin.jdbc.JdbcTableHandle;
import io.prestosql.plugin.jdbc.JdbcTypeHandle;
import io.prestosql.plugin.jdbc.PredicatePushdownController;
import io.prestosql.plugin.jdbc.WriteMapping;
+import io.prestosql.plugin.jdbc.expression.AggregateFunctionRewriter;
+import io.prestosql.plugin.jdbc.expression.AggregateFunctionRule;
+import io.prestosql.plugin.jdbc.expression.ImplementAvgDecimal;
+import io.prestosql.plugin.jdbc.expression.ImplementAvgFloatingPoint;
+import io.prestosql.plugin.jdbc.expression.ImplementCount;
+import io.prestosql.plugin.jdbc.expression.ImplementCountAll;
+import io.prestosql.plugin.jdbc.expression.ImplementMinMax;
+import io.prestosql.plugin.jdbc.expression.ImplementSum;
import io.prestosql.spi.PrestoException;
+import io.prestosql.spi.connector.AggregateFunction;
+import io.prestosql.spi.connector.ColumnHandle;
import io.prestosql.spi.connector.ConnectorSession;
import io.prestosql.spi.connector.ConnectorTableMetadata;
import io.prestosql.spi.connector.SchemaTableName;
+import io.prestosql.spi.type.DecimalType;
import io.prestosql.spi.type.Decimals;
import io.prestosql.spi.type.StandardTypes;
import io.prestosql.spi.type.Type;
@@ -56,6 +68,7 @@
import java.sql.Types;
import java.util.Collection;
import java.util.List;
+import java.util.Map;
import java.util.Optional;
import java.util.function.BiFunction;
@@ -104,12 +117,38 @@ public class MySqlClient
extends BaseJdbcClient
{
private final Type jsonType;
+ private final AggregateFunctionRewriter aggregateFunctionRewriter;
@Inject
public MySqlClient(BaseJdbcConfig config, ConnectionFactory connectionFactory, TypeManager typeManager)
{
super(config, "`", connectionFactory);
this.jsonType = typeManager.getType(new TypeSignature(StandardTypes.JSON));
+
+ JdbcTypeHandle bigintTypeHandle = new JdbcTypeHandle(Types.BIGINT, Optional.of("bigint"), 0, 0, Optional.empty(), Optional.empty());
+ this.aggregateFunctionRewriter = new AggregateFunctionRewriter(
+ this::quoted,
+ ImmutableSet.builder()
+ .add(new ImplementCountAll(bigintTypeHandle))
+ .add(new ImplementCount(bigintTypeHandle))
+ .add(new ImplementMinMax())
+ .add(new ImplementSum(MySqlClient::toTypeHandle))
+ .add(new ImplementAvgFloatingPoint())
+ .add(new ImplementAvgDecimal())
+ .add(new ImplementAvgBigint())
+ .build());
+ }
+
+ @Override
+ public Optional implementAggregation(ConnectorSession session, AggregateFunction aggregate, Map assignments)
+ {
+ // TODO support complex ConnectorExpressions
+ return aggregateFunctionRewriter.rewrite(session, aggregate, assignments);
+ }
+
+ private static Optional toTypeHandle(DecimalType decimalType)
+ {
+ return Optional.of(new JdbcTypeHandle(Types.NUMERIC, Optional.of("decimal"), decimalType.getPrecision(), decimalType.getScale(), Optional.empty(), Optional.empty()));
}
@Override
diff --git a/presto-mysql/src/test/java/io/prestosql/plugin/mysql/TestMySqlClient.java b/presto-mysql/src/test/java/io/prestosql/plugin/mysql/TestMySqlClient.java
new file mode 100644
index 000000000000..df1aa3635c55
--- /dev/null
+++ b/presto-mysql/src/test/java/io/prestosql/plugin/mysql/TestMySqlClient.java
@@ -0,0 +1,160 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.prestosql.plugin.mysql;
+
+import io.prestosql.plugin.jdbc.BaseJdbcConfig;
+import io.prestosql.plugin.jdbc.ColumnMapping;
+import io.prestosql.plugin.jdbc.JdbcClient;
+import io.prestosql.plugin.jdbc.JdbcColumnHandle;
+import io.prestosql.plugin.jdbc.JdbcExpression;
+import io.prestosql.plugin.jdbc.JdbcTypeHandle;
+import io.prestosql.spi.connector.AggregateFunction;
+import io.prestosql.spi.connector.ColumnHandle;
+import io.prestosql.spi.expression.ConnectorExpression;
+import io.prestosql.spi.expression.Variable;
+import io.prestosql.spi.type.TypeManager;
+import io.prestosql.type.InternalTypeManager;
+import org.testng.annotations.Test;
+
+import java.sql.Types;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+
+import static io.prestosql.metadata.MetadataManager.createTestMetadataManager;
+import static io.prestosql.spi.type.BigintType.BIGINT;
+import static io.prestosql.spi.type.BooleanType.BOOLEAN;
+import static io.prestosql.spi.type.DoubleType.DOUBLE;
+import static io.prestosql.testing.TestingConnectorSession.SESSION;
+import static io.prestosql.testing.assertions.Assert.assertEquals;
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.testng.Assert.assertTrue;
+
+public class TestMySqlClient
+{
+ private static final TypeManager TYPE_MANAGER = new InternalTypeManager(createTestMetadataManager());
+
+ private static final JdbcColumnHandle BIGINT_COLUMN =
+ JdbcColumnHandle.builder()
+ .setColumnName("c_bigint")
+ .setColumnType(BIGINT)
+ .setJdbcTypeHandle(new JdbcTypeHandle(Types.BIGINT, Optional.of("int8"), 0, 0, Optional.empty()))
+ .build();
+
+ private static final JdbcColumnHandle DOUBLE_COLUMN =
+ JdbcColumnHandle.builder()
+ .setColumnName("c_double")
+ .setColumnType(DOUBLE)
+ .setJdbcTypeHandle(new JdbcTypeHandle(Types.DOUBLE, Optional.of("double"), 0, 0, Optional.empty()))
+ .build();
+
+ private static final JdbcClient JDBC_CLIENT = new MySqlClient(
+ new BaseJdbcConfig(),
+ identity -> {
+ throw new UnsupportedOperationException();
+ },
+ TYPE_MANAGER);
+
+ @Test
+ public void testImplementCount()
+ {
+ Variable bigintVariable = new Variable("v_bigint", BIGINT);
+ Variable doubleVariable = new Variable("v_double", BIGINT);
+ Optional filter = Optional.of(new Variable("a_filter", BOOLEAN));
+
+ // count(*)
+ testImplementAggregation(
+ new AggregateFunction("count", BIGINT, List.of(), List.of(), false, Optional.empty()),
+ Map.of(),
+ Optional.of("count(*)"));
+
+ // count(bigint)
+ testImplementAggregation(
+ new AggregateFunction("count", BIGINT, List.of(bigintVariable), List.of(), false, Optional.empty()),
+ Map.of(bigintVariable.getName(), BIGINT_COLUMN),
+ Optional.of("count(`c_bigint`)"));
+
+ // count(double)
+ testImplementAggregation(
+ new AggregateFunction("count", BIGINT, List.of(doubleVariable), List.of(), false, Optional.empty()),
+ Map.of(doubleVariable.getName(), DOUBLE_COLUMN),
+ Optional.of("count(`c_double`)"));
+
+ // count(DISTINCT bigint)
+ testImplementAggregation(
+ new AggregateFunction("count", BIGINT, List.of(bigintVariable), List.of(), true, Optional.empty()),
+ Map.of(bigintVariable.getName(), BIGINT_COLUMN),
+ Optional.empty());
+
+ // count() FILTER (WHERE ...)
+
+ testImplementAggregation(
+ new AggregateFunction("count", BIGINT, List.of(), List.of(), false, filter),
+ Map.of(),
+ Optional.empty());
+
+ // count(bigint) FILTER (WHERE ...)
+ testImplementAggregation(
+ new AggregateFunction("count", BIGINT, List.of(bigintVariable), List.of(), false, filter),
+ Map.of(bigintVariable.getName(), BIGINT_COLUMN),
+ Optional.empty());
+ }
+
+ @Test
+ public void testImplementSum()
+ {
+ Variable bigintVariable = new Variable("v_bigint", BIGINT);
+ Variable doubleVariable = new Variable("v_double", DOUBLE);
+ Optional filter = Optional.of(new Variable("a_filter", BOOLEAN));
+
+ // sum(bigint)
+ testImplementAggregation(
+ new AggregateFunction("sum", BIGINT, List.of(bigintVariable), List.of(), false, Optional.empty()),
+ Map.of(bigintVariable.getName(), BIGINT_COLUMN),
+ Optional.of("sum(`c_bigint`)"));
+
+ // sum(double)
+ testImplementAggregation(
+ new AggregateFunction("sum", DOUBLE, List.of(doubleVariable), List.of(), false, Optional.empty()),
+ Map.of(doubleVariable.getName(), DOUBLE_COLUMN),
+ Optional.of("sum(`c_double`)"));
+
+ // sum(DISTINCT bigint)
+ testImplementAggregation(
+ new AggregateFunction("sum", BIGINT, List.of(bigintVariable), List.of(), true, Optional.empty()),
+ Map.of(bigintVariable.getName(), BIGINT_COLUMN),
+ Optional.empty()); // distinct not supported
+
+ // sum(bigint) FILTER (WHERE ...)
+ testImplementAggregation(
+ new AggregateFunction("sum", BIGINT, List.of(bigintVariable), List.of(), false, filter),
+ Map.of(bigintVariable.getName(), BIGINT_COLUMN),
+ Optional.empty()); // filter not supported
+ }
+
+ private void testImplementAggregation(AggregateFunction aggregateFunction, Map assignments, Optional expectedExpression)
+ {
+ Optional result = JDBC_CLIENT.implementAggregation(SESSION, aggregateFunction, assignments);
+ if (expectedExpression.isEmpty()) {
+ assertThat(result).isEmpty();
+ }
+ else {
+ assertThat(result).isPresent();
+ assertEquals(result.get().getExpression(), expectedExpression.get());
+ Optional columnMapping = JDBC_CLIENT.toPrestoType(SESSION, null, result.get().getJdbcTypeHandle());
+ assertTrue(columnMapping.isPresent(), "No mapping for: " + result.get().getJdbcTypeHandle());
+ assertEquals(columnMapping.get().getType(), aggregateFunction.getOutputType());
+ }
+ }
+}
diff --git a/presto-mysql/src/test/java/io/prestosql/plugin/mysql/TestMySqlIntegrationSmokeTest.java b/presto-mysql/src/test/java/io/prestosql/plugin/mysql/TestMySqlIntegrationSmokeTest.java
index 0b6b8f9969f7..03906c94b38e 100644
--- a/presto-mysql/src/test/java/io/prestosql/plugin/mysql/TestMySqlIntegrationSmokeTest.java
+++ b/presto-mysql/src/test/java/io/prestosql/plugin/mysql/TestMySqlIntegrationSmokeTest.java
@@ -226,6 +226,36 @@ public void testInsertIntoNotNullColumn()
assertUpdate("DROP TABLE test_insert_not_null");
}
+ @Test
+ public void testAggregationPushdown()
+ throws Exception
+ {
+ // TODO support aggregation pushdown with GROUPING SETS
+ // TODO support aggregation over expressions
+
+ assertThat(query("SELECT count(*) FROM nation")).isCorrectlyPushedDown();
+ assertThat(query("SELECT count(nationkey) FROM nation")).isCorrectlyPushedDown();
+ assertThat(query("SELECT count(1) FROM nation")).isCorrectlyPushedDown();
+ assertThat(query("SELECT count() FROM nation")).isCorrectlyPushedDown();
+ assertThat(query("SELECT regionkey, min(nationkey) FROM nation GROUP BY regionkey")).isCorrectlyPushedDown();
+ assertThat(query("SELECT regionkey, max(nationkey) FROM nation GROUP BY regionkey")).isCorrectlyPushedDown();
+ assertThat(query("SELECT regionkey, sum(nationkey) FROM nation GROUP BY regionkey")).isCorrectlyPushedDown();
+ assertThat(query("SELECT regionkey, avg(nationkey) FROM nation GROUP BY regionkey")).isCorrectlyPushedDown();
+
+ assertThat(query("SELECT regionkey, sum(nationkey) FROM nation WHERE regionkey < 4 GROUP BY regionkey")).isCorrectlyPushedDown();
+ assertThat(query("SELECT regionkey, sum(nationkey) FROM nation WHERE regionkey < 4 AND name > 'AAA' GROUP BY regionkey")).isNotFullyPushedDown(FilterNode.class);
+
+ try (AutoCloseable ignoreTable = withTable("tpch.test_aggregation_pushdown", "(short_decimal decimal(9, 3), long_decimal decimal(30, 10))")) {
+ execute("INSERT INTO tpch.test_aggregation_pushdown VALUES (100.000, 100000000.000000000)");
+ execute("INSERT INTO tpch.test_aggregation_pushdown VALUES (123.321, 123456789.987654321)");
+
+ assertThat(query("SELECT min(short_decimal), min(long_decimal) FROM test_aggregation_pushdown")).isCorrectlyPushedDown();
+ assertThat(query("SELECT max(short_decimal), max(long_decimal) FROM test_aggregation_pushdown")).isCorrectlyPushedDown();
+ assertThat(query("SELECT sum(short_decimal), sum(long_decimal) FROM test_aggregation_pushdown")).isCorrectlyPushedDown();
+ assertThat(query("SELECT avg(short_decimal), avg(long_decimal) FROM test_aggregation_pushdown")).isCorrectlyPushedDown();
+ }
+ }
+
@Test
public void testColumnComment()
{
@@ -272,6 +302,20 @@ public void testPredicatePushdown()
.isCorrectlyPushedDown();
}
+ private AutoCloseable withTable(String tableName, String tableDefinition)
+ throws Exception
+ {
+ execute(format("CREATE TABLE %s%s", tableName, tableDefinition));
+ return () -> {
+ try {
+ execute(format("DROP TABLE %s", tableName));
+ }
+ catch (RuntimeException e) {
+ throw new RuntimeException(e);
+ }
+ };
+ }
+
private void execute(String sql)
{
mysqlServer.execute(sql, mysqlServer.getUsername(), mysqlServer.getPassword());