diff --git a/presto-mysql/pom.xml b/presto-mysql/pom.xml index 1bb3ca454711..7786cf0e2b35 100644 --- a/presto-mysql/pom.xml +++ b/presto-mysql/pom.xml @@ -22,6 +22,11 @@ presto-base-jdbc + + io.prestosql + presto-matching + + io.airlift configuration diff --git a/presto-mysql/src/main/java/io/prestosql/plugin/mysql/ImplementAvgBigint.java b/presto-mysql/src/main/java/io/prestosql/plugin/mysql/ImplementAvgBigint.java new file mode 100644 index 000000000000..29ced9fe0b49 --- /dev/null +++ b/presto-mysql/src/main/java/io/prestosql/plugin/mysql/ImplementAvgBigint.java @@ -0,0 +1,69 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.plugin.mysql; + +import io.prestosql.matching.Capture; +import io.prestosql.matching.Captures; +import io.prestosql.matching.Pattern; +import io.prestosql.plugin.jdbc.JdbcColumnHandle; +import io.prestosql.plugin.jdbc.JdbcExpression; +import io.prestosql.plugin.jdbc.JdbcTypeHandle; +import io.prestosql.plugin.jdbc.expression.AggregateFunctionRule; +import io.prestosql.spi.connector.AggregateFunction; +import io.prestosql.spi.expression.Variable; + +import java.sql.Types; +import java.util.Optional; + +import static com.google.common.base.Verify.verify; +import static com.google.common.base.Verify.verifyNotNull; +import static io.prestosql.matching.Capture.newCapture; +import static io.prestosql.plugin.jdbc.expression.AggregateFunctionPatterns.basicAggregation; +import static io.prestosql.plugin.jdbc.expression.AggregateFunctionPatterns.expressionType; +import static io.prestosql.plugin.jdbc.expression.AggregateFunctionPatterns.functionName; +import static io.prestosql.plugin.jdbc.expression.AggregateFunctionPatterns.singleInput; +import static io.prestosql.plugin.jdbc.expression.AggregateFunctionPatterns.variable; +import static io.prestosql.spi.type.BigintType.BIGINT; +import static io.prestosql.spi.type.DoubleType.DOUBLE; +import static java.lang.String.format; + +public class ImplementAvgBigint + implements AggregateFunctionRule +{ + private static final Capture INPUT = newCapture(); + + @Override + public Pattern getPattern() + { + return basicAggregation() + .with(functionName().equalTo("avg")) + .with(singleInput().matching( + variable() + .with(expressionType().matching(type -> type == BIGINT)) + .capturedAs(INPUT))); + } + + @Override + public Optional rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context) + { + Variable input = captures.get(INPUT); + JdbcColumnHandle columnHandle = (JdbcColumnHandle) context.getAssignments().get(input.getName()); + verifyNotNull(columnHandle, "Unbound variable: %s", input); + verify(aggregateFunction.getOutputType() == DOUBLE); + + return Optional.of(new JdbcExpression( + format("avg((%s * 1.0))", columnHandle.toSqlExpression(context.getIdentifierQuote())), + new JdbcTypeHandle(Types.DOUBLE, Optional.of("double"), 0, 0, Optional.empty(), Optional.empty()))); + } +} diff --git a/presto-mysql/src/main/java/io/prestosql/plugin/mysql/MySqlClient.java b/presto-mysql/src/main/java/io/prestosql/plugin/mysql/MySqlClient.java index d7e5fda03726..ad7c15a011a9 100644 --- a/presto-mysql/src/main/java/io/prestosql/plugin/mysql/MySqlClient.java +++ b/presto-mysql/src/main/java/io/prestosql/plugin/mysql/MySqlClient.java @@ -27,15 +27,27 @@ import io.prestosql.plugin.jdbc.ColumnMapping; import io.prestosql.plugin.jdbc.ConnectionFactory; import io.prestosql.plugin.jdbc.JdbcColumnHandle; +import io.prestosql.plugin.jdbc.JdbcExpression; import io.prestosql.plugin.jdbc.JdbcIdentity; import io.prestosql.plugin.jdbc.JdbcTableHandle; import io.prestosql.plugin.jdbc.JdbcTypeHandle; import io.prestosql.plugin.jdbc.PredicatePushdownController; import io.prestosql.plugin.jdbc.WriteMapping; +import io.prestosql.plugin.jdbc.expression.AggregateFunctionRewriter; +import io.prestosql.plugin.jdbc.expression.AggregateFunctionRule; +import io.prestosql.plugin.jdbc.expression.ImplementAvgDecimal; +import io.prestosql.plugin.jdbc.expression.ImplementAvgFloatingPoint; +import io.prestosql.plugin.jdbc.expression.ImplementCount; +import io.prestosql.plugin.jdbc.expression.ImplementCountAll; +import io.prestosql.plugin.jdbc.expression.ImplementMinMax; +import io.prestosql.plugin.jdbc.expression.ImplementSum; import io.prestosql.spi.PrestoException; +import io.prestosql.spi.connector.AggregateFunction; +import io.prestosql.spi.connector.ColumnHandle; import io.prestosql.spi.connector.ConnectorSession; import io.prestosql.spi.connector.ConnectorTableMetadata; import io.prestosql.spi.connector.SchemaTableName; +import io.prestosql.spi.type.DecimalType; import io.prestosql.spi.type.Decimals; import io.prestosql.spi.type.StandardTypes; import io.prestosql.spi.type.Type; @@ -56,6 +68,7 @@ import java.sql.Types; import java.util.Collection; import java.util.List; +import java.util.Map; import java.util.Optional; import java.util.function.BiFunction; @@ -104,12 +117,38 @@ public class MySqlClient extends BaseJdbcClient { private final Type jsonType; + private final AggregateFunctionRewriter aggregateFunctionRewriter; @Inject public MySqlClient(BaseJdbcConfig config, ConnectionFactory connectionFactory, TypeManager typeManager) { super(config, "`", connectionFactory); this.jsonType = typeManager.getType(new TypeSignature(StandardTypes.JSON)); + + JdbcTypeHandle bigintTypeHandle = new JdbcTypeHandle(Types.BIGINT, Optional.of("bigint"), 0, 0, Optional.empty(), Optional.empty()); + this.aggregateFunctionRewriter = new AggregateFunctionRewriter( + this::quoted, + ImmutableSet.builder() + .add(new ImplementCountAll(bigintTypeHandle)) + .add(new ImplementCount(bigintTypeHandle)) + .add(new ImplementMinMax()) + .add(new ImplementSum(MySqlClient::toTypeHandle)) + .add(new ImplementAvgFloatingPoint()) + .add(new ImplementAvgDecimal()) + .add(new ImplementAvgBigint()) + .build()); + } + + @Override + public Optional implementAggregation(ConnectorSession session, AggregateFunction aggregate, Map assignments) + { + // TODO support complex ConnectorExpressions + return aggregateFunctionRewriter.rewrite(session, aggregate, assignments); + } + + private static Optional toTypeHandle(DecimalType decimalType) + { + return Optional.of(new JdbcTypeHandle(Types.NUMERIC, Optional.of("decimal"), decimalType.getPrecision(), decimalType.getScale(), Optional.empty(), Optional.empty())); } @Override diff --git a/presto-mysql/src/test/java/io/prestosql/plugin/mysql/TestMySqlClient.java b/presto-mysql/src/test/java/io/prestosql/plugin/mysql/TestMySqlClient.java new file mode 100644 index 000000000000..df1aa3635c55 --- /dev/null +++ b/presto-mysql/src/test/java/io/prestosql/plugin/mysql/TestMySqlClient.java @@ -0,0 +1,160 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.plugin.mysql; + +import io.prestosql.plugin.jdbc.BaseJdbcConfig; +import io.prestosql.plugin.jdbc.ColumnMapping; +import io.prestosql.plugin.jdbc.JdbcClient; +import io.prestosql.plugin.jdbc.JdbcColumnHandle; +import io.prestosql.plugin.jdbc.JdbcExpression; +import io.prestosql.plugin.jdbc.JdbcTypeHandle; +import io.prestosql.spi.connector.AggregateFunction; +import io.prestosql.spi.connector.ColumnHandle; +import io.prestosql.spi.expression.ConnectorExpression; +import io.prestosql.spi.expression.Variable; +import io.prestosql.spi.type.TypeManager; +import io.prestosql.type.InternalTypeManager; +import org.testng.annotations.Test; + +import java.sql.Types; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +import static io.prestosql.metadata.MetadataManager.createTestMetadataManager; +import static io.prestosql.spi.type.BigintType.BIGINT; +import static io.prestosql.spi.type.BooleanType.BOOLEAN; +import static io.prestosql.spi.type.DoubleType.DOUBLE; +import static io.prestosql.testing.TestingConnectorSession.SESSION; +import static io.prestosql.testing.assertions.Assert.assertEquals; +import static org.assertj.core.api.Assertions.assertThat; +import static org.testng.Assert.assertTrue; + +public class TestMySqlClient +{ + private static final TypeManager TYPE_MANAGER = new InternalTypeManager(createTestMetadataManager()); + + private static final JdbcColumnHandle BIGINT_COLUMN = + JdbcColumnHandle.builder() + .setColumnName("c_bigint") + .setColumnType(BIGINT) + .setJdbcTypeHandle(new JdbcTypeHandle(Types.BIGINT, Optional.of("int8"), 0, 0, Optional.empty())) + .build(); + + private static final JdbcColumnHandle DOUBLE_COLUMN = + JdbcColumnHandle.builder() + .setColumnName("c_double") + .setColumnType(DOUBLE) + .setJdbcTypeHandle(new JdbcTypeHandle(Types.DOUBLE, Optional.of("double"), 0, 0, Optional.empty())) + .build(); + + private static final JdbcClient JDBC_CLIENT = new MySqlClient( + new BaseJdbcConfig(), + identity -> { + throw new UnsupportedOperationException(); + }, + TYPE_MANAGER); + + @Test + public void testImplementCount() + { + Variable bigintVariable = new Variable("v_bigint", BIGINT); + Variable doubleVariable = new Variable("v_double", BIGINT); + Optional filter = Optional.of(new Variable("a_filter", BOOLEAN)); + + // count(*) + testImplementAggregation( + new AggregateFunction("count", BIGINT, List.of(), List.of(), false, Optional.empty()), + Map.of(), + Optional.of("count(*)")); + + // count(bigint) + testImplementAggregation( + new AggregateFunction("count", BIGINT, List.of(bigintVariable), List.of(), false, Optional.empty()), + Map.of(bigintVariable.getName(), BIGINT_COLUMN), + Optional.of("count(`c_bigint`)")); + + // count(double) + testImplementAggregation( + new AggregateFunction("count", BIGINT, List.of(doubleVariable), List.of(), false, Optional.empty()), + Map.of(doubleVariable.getName(), DOUBLE_COLUMN), + Optional.of("count(`c_double`)")); + + // count(DISTINCT bigint) + testImplementAggregation( + new AggregateFunction("count", BIGINT, List.of(bigintVariable), List.of(), true, Optional.empty()), + Map.of(bigintVariable.getName(), BIGINT_COLUMN), + Optional.empty()); + + // count() FILTER (WHERE ...) + + testImplementAggregation( + new AggregateFunction("count", BIGINT, List.of(), List.of(), false, filter), + Map.of(), + Optional.empty()); + + // count(bigint) FILTER (WHERE ...) + testImplementAggregation( + new AggregateFunction("count", BIGINT, List.of(bigintVariable), List.of(), false, filter), + Map.of(bigintVariable.getName(), BIGINT_COLUMN), + Optional.empty()); + } + + @Test + public void testImplementSum() + { + Variable bigintVariable = new Variable("v_bigint", BIGINT); + Variable doubleVariable = new Variable("v_double", DOUBLE); + Optional filter = Optional.of(new Variable("a_filter", BOOLEAN)); + + // sum(bigint) + testImplementAggregation( + new AggregateFunction("sum", BIGINT, List.of(bigintVariable), List.of(), false, Optional.empty()), + Map.of(bigintVariable.getName(), BIGINT_COLUMN), + Optional.of("sum(`c_bigint`)")); + + // sum(double) + testImplementAggregation( + new AggregateFunction("sum", DOUBLE, List.of(doubleVariable), List.of(), false, Optional.empty()), + Map.of(doubleVariable.getName(), DOUBLE_COLUMN), + Optional.of("sum(`c_double`)")); + + // sum(DISTINCT bigint) + testImplementAggregation( + new AggregateFunction("sum", BIGINT, List.of(bigintVariable), List.of(), true, Optional.empty()), + Map.of(bigintVariable.getName(), BIGINT_COLUMN), + Optional.empty()); // distinct not supported + + // sum(bigint) FILTER (WHERE ...) + testImplementAggregation( + new AggregateFunction("sum", BIGINT, List.of(bigintVariable), List.of(), false, filter), + Map.of(bigintVariable.getName(), BIGINT_COLUMN), + Optional.empty()); // filter not supported + } + + private void testImplementAggregation(AggregateFunction aggregateFunction, Map assignments, Optional expectedExpression) + { + Optional result = JDBC_CLIENT.implementAggregation(SESSION, aggregateFunction, assignments); + if (expectedExpression.isEmpty()) { + assertThat(result).isEmpty(); + } + else { + assertThat(result).isPresent(); + assertEquals(result.get().getExpression(), expectedExpression.get()); + Optional columnMapping = JDBC_CLIENT.toPrestoType(SESSION, null, result.get().getJdbcTypeHandle()); + assertTrue(columnMapping.isPresent(), "No mapping for: " + result.get().getJdbcTypeHandle()); + assertEquals(columnMapping.get().getType(), aggregateFunction.getOutputType()); + } + } +} diff --git a/presto-mysql/src/test/java/io/prestosql/plugin/mysql/TestMySqlIntegrationSmokeTest.java b/presto-mysql/src/test/java/io/prestosql/plugin/mysql/TestMySqlIntegrationSmokeTest.java index 0b6b8f9969f7..03906c94b38e 100644 --- a/presto-mysql/src/test/java/io/prestosql/plugin/mysql/TestMySqlIntegrationSmokeTest.java +++ b/presto-mysql/src/test/java/io/prestosql/plugin/mysql/TestMySqlIntegrationSmokeTest.java @@ -226,6 +226,36 @@ public void testInsertIntoNotNullColumn() assertUpdate("DROP TABLE test_insert_not_null"); } + @Test + public void testAggregationPushdown() + throws Exception + { + // TODO support aggregation pushdown with GROUPING SETS + // TODO support aggregation over expressions + + assertThat(query("SELECT count(*) FROM nation")).isCorrectlyPushedDown(); + assertThat(query("SELECT count(nationkey) FROM nation")).isCorrectlyPushedDown(); + assertThat(query("SELECT count(1) FROM nation")).isCorrectlyPushedDown(); + assertThat(query("SELECT count() FROM nation")).isCorrectlyPushedDown(); + assertThat(query("SELECT regionkey, min(nationkey) FROM nation GROUP BY regionkey")).isCorrectlyPushedDown(); + assertThat(query("SELECT regionkey, max(nationkey) FROM nation GROUP BY regionkey")).isCorrectlyPushedDown(); + assertThat(query("SELECT regionkey, sum(nationkey) FROM nation GROUP BY regionkey")).isCorrectlyPushedDown(); + assertThat(query("SELECT regionkey, avg(nationkey) FROM nation GROUP BY regionkey")).isCorrectlyPushedDown(); + + assertThat(query("SELECT regionkey, sum(nationkey) FROM nation WHERE regionkey < 4 GROUP BY regionkey")).isCorrectlyPushedDown(); + assertThat(query("SELECT regionkey, sum(nationkey) FROM nation WHERE regionkey < 4 AND name > 'AAA' GROUP BY regionkey")).isNotFullyPushedDown(FilterNode.class); + + try (AutoCloseable ignoreTable = withTable("tpch.test_aggregation_pushdown", "(short_decimal decimal(9, 3), long_decimal decimal(30, 10))")) { + execute("INSERT INTO tpch.test_aggregation_pushdown VALUES (100.000, 100000000.000000000)"); + execute("INSERT INTO tpch.test_aggregation_pushdown VALUES (123.321, 123456789.987654321)"); + + assertThat(query("SELECT min(short_decimal), min(long_decimal) FROM test_aggregation_pushdown")).isCorrectlyPushedDown(); + assertThat(query("SELECT max(short_decimal), max(long_decimal) FROM test_aggregation_pushdown")).isCorrectlyPushedDown(); + assertThat(query("SELECT sum(short_decimal), sum(long_decimal) FROM test_aggregation_pushdown")).isCorrectlyPushedDown(); + assertThat(query("SELECT avg(short_decimal), avg(long_decimal) FROM test_aggregation_pushdown")).isCorrectlyPushedDown(); + } + } + @Test public void testColumnComment() { @@ -272,6 +302,20 @@ public void testPredicatePushdown() .isCorrectlyPushedDown(); } + private AutoCloseable withTable(String tableName, String tableDefinition) + throws Exception + { + execute(format("CREATE TABLE %s%s", tableName, tableDefinition)); + return () -> { + try { + execute(format("DROP TABLE %s", tableName)); + } + catch (RuntimeException e) { + throw new RuntimeException(e); + } + }; + } + private void execute(String sql) { mysqlServer.execute(sql, mysqlServer.getUsername(), mysqlServer.getPassword());