Skip to content

Commit

Permalink
Support aggregation pushdown in MySQL connector
Browse files Browse the repository at this point in the history
  • Loading branch information
yuokada authored and findepi committed Aug 26, 2020
1 parent 31adff9 commit 193872a
Show file tree
Hide file tree
Showing 5 changed files with 317 additions and 0 deletions.
5 changes: 5 additions & 0 deletions presto-mysql/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@
<artifactId>presto-base-jdbc</artifactId>
</dependency>

<!-- Pattern-matching library (io.prestosql.matching) used by the connector's aggregate function rewrite rules -->
<dependency>
<groupId>io.prestosql</groupId>
<artifactId>presto-matching</artifactId>
</dependency>

<dependency>
<groupId>io.airlift</groupId>
<artifactId>configuration</artifactId>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.prestosql.plugin.mysql;

import io.prestosql.matching.Capture;
import io.prestosql.matching.Captures;
import io.prestosql.matching.Pattern;
import io.prestosql.plugin.jdbc.JdbcColumnHandle;
import io.prestosql.plugin.jdbc.JdbcExpression;
import io.prestosql.plugin.jdbc.JdbcTypeHandle;
import io.prestosql.plugin.jdbc.expression.AggregateFunctionRule;
import io.prestosql.spi.connector.AggregateFunction;
import io.prestosql.spi.expression.Variable;

import java.sql.Types;
import java.util.Optional;

import static com.google.common.base.Verify.verify;
import static com.google.common.base.Verify.verifyNotNull;
import static io.prestosql.matching.Capture.newCapture;
import static io.prestosql.plugin.jdbc.expression.AggregateFunctionPatterns.basicAggregation;
import static io.prestosql.plugin.jdbc.expression.AggregateFunctionPatterns.expressionType;
import static io.prestosql.plugin.jdbc.expression.AggregateFunctionPatterns.functionName;
import static io.prestosql.plugin.jdbc.expression.AggregateFunctionPatterns.singleInput;
import static io.prestosql.plugin.jdbc.expression.AggregateFunctionPatterns.variable;
import static io.prestosql.spi.type.BigintType.BIGINT;
import static io.prestosql.spi.type.DoubleType.DOUBLE;
import static java.lang.String.format;

/**
 * Aggregate function rule that pushes {@code avg} over a {@code BIGINT} column down to MySQL.
 * <p>
 * The generated SQL multiplies the column by {@code 1.0} before averaging and declares the
 * result as a JDBC {@code DOUBLE} (presumably so the remote average is not computed with
 * integer semantics; confirm against MySQL's numeric rules before changing the expression).
 */
public class ImplementAvgBigint
        implements AggregateFunctionRule
{
    // Capture for the single BIGINT variable that avg() is applied to.
    private static final Capture<Variable> INPUT = newCapture();

    /**
     * Describes the aggregations this rule handles: whatever {@code basicAggregation()} accepts,
     * named {@code avg}, with exactly one input that is a {@code BIGINT}-typed variable.
     *
     * @return the pattern, with the input variable captured as {@link #INPUT}
     */
    @Override
    public Pattern<AggregateFunction> getPattern()
    {
        Pattern<Variable> bigintArgument = variable()
                .with(expressionType().matching(type -> type == BIGINT))
                .capturedAs(INPUT);

        return basicAggregation()
                .with(functionName().equalTo("avg"))
                .with(singleInput().matching(bigintArgument));
    }

    /**
     * Builds the remote SQL expression for a matched {@code avg(bigint)} aggregation.
     *
     * @param aggregateFunction the matched aggregation; its output type must be {@code DOUBLE}
     * @param captures captures produced by {@link #getPattern()}
     * @param context supplies the variable-to-column assignments and the identifier quote
     * @return the rewritten expression, always present
     */
    @Override
    public Optional<JdbcExpression> rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context)
    {
        Variable avgArgument = captures.get(INPUT);
        JdbcColumnHandle column = (JdbcColumnHandle) context.getAssignments().get(avgArgument.getName());
        verifyNotNull(column, "Unbound variable: %s", avgArgument);
        verify(aggregateFunction.getOutputType() == DOUBLE);

        String remoteExpression = format("avg((%s * 1.0))", column.toSqlExpression(context.getIdentifierQuote()));
        JdbcTypeHandle doubleResult = new JdbcTypeHandle(Types.DOUBLE, Optional.of("double"), 0, 0, Optional.empty(), Optional.empty());
        return Optional.of(new JdbcExpression(remoteExpression, doubleResult));
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,27 @@
import io.prestosql.plugin.jdbc.ColumnMapping;
import io.prestosql.plugin.jdbc.ConnectionFactory;
import io.prestosql.plugin.jdbc.JdbcColumnHandle;
import io.prestosql.plugin.jdbc.JdbcExpression;
import io.prestosql.plugin.jdbc.JdbcIdentity;
import io.prestosql.plugin.jdbc.JdbcTableHandle;
import io.prestosql.plugin.jdbc.JdbcTypeHandle;
import io.prestosql.plugin.jdbc.PredicatePushdownController;
import io.prestosql.plugin.jdbc.WriteMapping;
import io.prestosql.plugin.jdbc.expression.AggregateFunctionRewriter;
import io.prestosql.plugin.jdbc.expression.AggregateFunctionRule;
import io.prestosql.plugin.jdbc.expression.ImplementAvgDecimal;
import io.prestosql.plugin.jdbc.expression.ImplementAvgFloatingPoint;
import io.prestosql.plugin.jdbc.expression.ImplementCount;
import io.prestosql.plugin.jdbc.expression.ImplementCountAll;
import io.prestosql.plugin.jdbc.expression.ImplementMinMax;
import io.prestosql.plugin.jdbc.expression.ImplementSum;
import io.prestosql.spi.PrestoException;
import io.prestosql.spi.connector.AggregateFunction;
import io.prestosql.spi.connector.ColumnHandle;
import io.prestosql.spi.connector.ConnectorSession;
import io.prestosql.spi.connector.ConnectorTableMetadata;
import io.prestosql.spi.connector.SchemaTableName;
import io.prestosql.spi.type.DecimalType;
import io.prestosql.spi.type.Decimals;
import io.prestosql.spi.type.StandardTypes;
import io.prestosql.spi.type.Type;
Expand All @@ -56,6 +68,7 @@
import java.sql.Types;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.BiFunction;

Expand Down Expand Up @@ -104,12 +117,38 @@ public class MySqlClient
extends BaseJdbcClient
{
private final Type jsonType;
private final AggregateFunctionRewriter aggregateFunctionRewriter;

/**
 * Creates the MySQL JDBC client.
 * <p>
 * Uses a backtick as the identifier quote, resolves the JSON type from the given type manager,
 * and registers the aggregate function rules used by {@code implementAggregation}.
 */
@Inject
public MySqlClient(BaseJdbcConfig config, ConnectionFactory connectionFactory, TypeManager typeManager)
{
    super(config, "`", connectionFactory);
    this.jsonType = typeManager.getType(new TypeSignature(StandardTypes.JSON));

    // Type handle passed to both count rules (count(*) and count(column)).
    JdbcTypeHandle countResultType = new JdbcTypeHandle(Types.BIGINT, Optional.of("bigint"), 0, 0, Optional.empty(), Optional.empty());

    // Insertion order is preserved by ImmutableSet, so rules are kept in the order listed here.
    ImmutableSet<AggregateFunctionRule> pushdownRules = ImmutableSet.<AggregateFunctionRule>of(
            new ImplementCountAll(countResultType),
            new ImplementCount(countResultType),
            new ImplementMinMax(),
            new ImplementSum(MySqlClient::toTypeHandle),
            new ImplementAvgFloatingPoint(),
            new ImplementAvgDecimal(),
            new ImplementAvgBigint());

    this.aggregateFunctionRewriter = new AggregateFunctionRewriter(this::quoted, pushdownRules);
}

/**
 * Attempts to translate an aggregate function into an expression that can be evaluated remotely.
 * <p>
 * This is a pure delegation to the {@code aggregateFunctionRewriter} configured in the
 * constructor; the result is whatever that rewriter returns for the given aggregate and
 * variable-to-column assignments.
 */
@Override
public Optional<JdbcExpression> implementAggregation(ConnectorSession session, AggregateFunction aggregate, Map<String, ColumnHandle> assignments)
{
// TODO support complex ConnectorExpressions
return aggregateFunctionRewriter.rewrite(session, aggregate, assignments);
}

/**
 * Builds the JDBC type handle describing a decimal of the given precision and scale.
 * Passed to {@code ImplementSum} in the constructor as a method reference.
 *
 * @param decimalType the decimal type whose precision and scale are copied into the handle
 * @return a {@code NUMERIC}/"decimal" type handle, always present
 */
private static Optional<JdbcTypeHandle> toTypeHandle(DecimalType decimalType)
{
    int precision = decimalType.getPrecision();
    int scale = decimalType.getScale();
    JdbcTypeHandle decimalHandle = new JdbcTypeHandle(Types.NUMERIC, Optional.of("decimal"), precision, scale, Optional.empty(), Optional.empty());
    return Optional.of(decimalHandle);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.prestosql.plugin.mysql;

import io.prestosql.plugin.jdbc.BaseJdbcConfig;
import io.prestosql.plugin.jdbc.ColumnMapping;
import io.prestosql.plugin.jdbc.JdbcClient;
import io.prestosql.plugin.jdbc.JdbcColumnHandle;
import io.prestosql.plugin.jdbc.JdbcExpression;
import io.prestosql.plugin.jdbc.JdbcTypeHandle;
import io.prestosql.spi.connector.AggregateFunction;
import io.prestosql.spi.connector.ColumnHandle;
import io.prestosql.spi.expression.ConnectorExpression;
import io.prestosql.spi.expression.Variable;
import io.prestosql.spi.type.TypeManager;
import io.prestosql.type.InternalTypeManager;
import org.testng.annotations.Test;

import java.sql.Types;
import java.util.List;
import java.util.Map;
import java.util.Optional;

import static io.prestosql.metadata.MetadataManager.createTestMetadataManager;
import static io.prestosql.spi.type.BigintType.BIGINT;
import static io.prestosql.spi.type.BooleanType.BOOLEAN;
import static io.prestosql.spi.type.DoubleType.DOUBLE;
import static io.prestosql.testing.TestingConnectorSession.SESSION;
import static io.prestosql.testing.assertions.Assert.assertEquals;
import static org.assertj.core.api.Assertions.assertThat;
import static org.testng.Assert.assertTrue;

/**
 * Unit tests for aggregation pushdown in {@link MySqlClient}.
 * <p>
 * The client is built with a connection factory that always throws, so these tests only
 * exercise {@code implementAggregation} and {@code toPrestoType}, neither of which is
 * given a live connection here.
 */
public class TestMySqlClient
{
    private static final TypeManager TYPE_MANAGER = new InternalTypeManager(createTestMetadataManager());

    // Type name matches the "bigint" handle MySqlClient itself uses for BIGINT results.
    private static final JdbcColumnHandle BIGINT_COLUMN =
            JdbcColumnHandle.builder()
                    .setColumnName("c_bigint")
                    .setColumnType(BIGINT)
                    .setJdbcTypeHandle(new JdbcTypeHandle(Types.BIGINT, Optional.of("bigint"), 0, 0, Optional.empty()))
                    .build();

    private static final JdbcColumnHandle DOUBLE_COLUMN =
            JdbcColumnHandle.builder()
                    .setColumnName("c_double")
                    .setColumnType(DOUBLE)
                    .setJdbcTypeHandle(new JdbcTypeHandle(Types.DOUBLE, Optional.of("double"), 0, 0, Optional.empty()))
                    .build();

    // No test in this class is expected to open a connection, so the factory simply fails.
    private static final JdbcClient JDBC_CLIENT = new MySqlClient(
            new BaseJdbcConfig(),
            identity -> {
                throw new UnsupportedOperationException();
            },
            TYPE_MANAGER);

    /**
     * Verifies which forms of {@code count} are rewritten and which are rejected
     * (DISTINCT and FILTER are expected to yield no pushdown expression).
     */
    @Test
    public void testImplementCount()
    {
        Variable bigintVariable = new Variable("v_bigint", BIGINT);
        // Declared as DOUBLE so the variable's type agrees with DOUBLE_COLUMN it is bound to below.
        Variable doubleVariable = new Variable("v_double", DOUBLE);
        Optional<ConnectorExpression> filter = Optional.of(new Variable("a_filter", BOOLEAN));

        // count(*)
        testImplementAggregation(
                new AggregateFunction("count", BIGINT, List.of(), List.of(), false, Optional.empty()),
                Map.of(),
                Optional.of("count(*)"));

        // count(bigint)
        testImplementAggregation(
                new AggregateFunction("count", BIGINT, List.of(bigintVariable), List.of(), false, Optional.empty()),
                Map.of(bigintVariable.getName(), BIGINT_COLUMN),
                Optional.of("count(`c_bigint`)"));

        // count(double)
        testImplementAggregation(
                new AggregateFunction("count", BIGINT, List.of(doubleVariable), List.of(), false, Optional.empty()),
                Map.of(doubleVariable.getName(), DOUBLE_COLUMN),
                Optional.of("count(`c_double`)"));

        // count(DISTINCT bigint)
        testImplementAggregation(
                new AggregateFunction("count", BIGINT, List.of(bigintVariable), List.of(), true, Optional.empty()),
                Map.of(bigintVariable.getName(), BIGINT_COLUMN),
                Optional.empty()); // distinct not supported

        // count() FILTER (WHERE ...)
        testImplementAggregation(
                new AggregateFunction("count", BIGINT, List.of(), List.of(), false, filter),
                Map.of(),
                Optional.empty()); // filter not supported

        // count(bigint) FILTER (WHERE ...)
        testImplementAggregation(
                new AggregateFunction("count", BIGINT, List.of(bigintVariable), List.of(), false, filter),
                Map.of(bigintVariable.getName(), BIGINT_COLUMN),
                Optional.empty()); // filter not supported
    }

    /**
     * Verifies which forms of {@code sum} are rewritten and which are rejected
     * (DISTINCT and FILTER are expected to yield no pushdown expression).
     */
    @Test
    public void testImplementSum()
    {
        Variable bigintVariable = new Variable("v_bigint", BIGINT);
        Variable doubleVariable = new Variable("v_double", DOUBLE);
        Optional<ConnectorExpression> filter = Optional.of(new Variable("a_filter", BOOLEAN));

        // sum(bigint)
        testImplementAggregation(
                new AggregateFunction("sum", BIGINT, List.of(bigintVariable), List.of(), false, Optional.empty()),
                Map.of(bigintVariable.getName(), BIGINT_COLUMN),
                Optional.of("sum(`c_bigint`)"));

        // sum(double)
        testImplementAggregation(
                new AggregateFunction("sum", DOUBLE, List.of(doubleVariable), List.of(), false, Optional.empty()),
                Map.of(doubleVariable.getName(), DOUBLE_COLUMN),
                Optional.of("sum(`c_double`)"));

        // sum(DISTINCT bigint)
        testImplementAggregation(
                new AggregateFunction("sum", BIGINT, List.of(bigintVariable), List.of(), true, Optional.empty()),
                Map.of(bigintVariable.getName(), BIGINT_COLUMN),
                Optional.empty()); // distinct not supported

        // sum(bigint) FILTER (WHERE ...)
        testImplementAggregation(
                new AggregateFunction("sum", BIGINT, List.of(bigintVariable), List.of(), false, filter),
                Map.of(bigintVariable.getName(), BIGINT_COLUMN),
                Optional.empty()); // filter not supported
    }

    /**
     * Runs {@code implementAggregation} and checks the outcome.
     *
     * @param aggregateFunction the aggregation to rewrite
     * @param assignments variable name to column handle bindings for the aggregation's inputs
     * @param expectedExpression the expected SQL text, or empty when no pushdown is expected
     */
    private void testImplementAggregation(AggregateFunction aggregateFunction, Map<String, ColumnHandle> assignments, Optional<String> expectedExpression)
    {
        Optional<JdbcExpression> result = JDBC_CLIENT.implementAggregation(SESSION, aggregateFunction, assignments);
        if (expectedExpression.isEmpty()) {
            assertThat(result).isEmpty();
        }
        else {
            assertThat(result).isPresent();
            assertEquals(result.get().getExpression(), expectedExpression.get());
            // The returned type handle must map back to the aggregation's declared output type.
            Optional<ColumnMapping> columnMapping = JDBC_CLIENT.toPrestoType(SESSION, null, result.get().getJdbcTypeHandle());
            assertTrue(columnMapping.isPresent(), "No mapping for: " + result.get().getJdbcTypeHandle());
            assertEquals(columnMapping.get().getType(), aggregateFunction.getOutputType());
        }
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,36 @@ public void testInsertIntoNotNullColumn()
assertUpdate("DROP TABLE test_insert_not_null");
}

/**
 * End-to-end check that supported aggregations over {@code nation} and over a temporary
 * decimal table are pushed down, and that a query with an additional {@code name > 'AAA'}
 * predicate is reported as not fully pushed down (a {@code FilterNode} remains).
 */
@Test
public void testAggregationPushdown()
        throws Exception
{
    // TODO support aggregation pushdown with GROUPING SETS
    // TODO support aggregation over expressions

    String[] pushedDownOverNation = {
            "SELECT count(*) FROM nation",
            "SELECT count(nationkey) FROM nation",
            "SELECT count(1) FROM nation",
            "SELECT count() FROM nation",
            "SELECT regionkey, min(nationkey) FROM nation GROUP BY regionkey",
            "SELECT regionkey, max(nationkey) FROM nation GROUP BY regionkey",
            "SELECT regionkey, sum(nationkey) FROM nation GROUP BY regionkey",
            "SELECT regionkey, avg(nationkey) FROM nation GROUP BY regionkey",
            "SELECT regionkey, sum(nationkey) FROM nation WHERE regionkey < 4 GROUP BY regionkey",
    };
    for (String sql : pushedDownOverNation) {
        assertThat(query(sql)).isCorrectlyPushedDown();
    }

    assertThat(query("SELECT regionkey, sum(nationkey) FROM nation WHERE regionkey < 4 AND name > 'AAA' GROUP BY regionkey"))
            .isNotFullyPushedDown(FilterNode.class);

    try (AutoCloseable ignoreTable = withTable("tpch.test_aggregation_pushdown", "(short_decimal decimal(9, 3), long_decimal decimal(30, 10))")) {
        String[] rows = {
                "INSERT INTO tpch.test_aggregation_pushdown VALUES (100.000, 100000000.000000000)",
                "INSERT INTO tpch.test_aggregation_pushdown VALUES (123.321, 123456789.987654321)",
        };
        for (String insert : rows) {
            execute(insert);
        }

        String[] pushedDownOverDecimals = {
                "SELECT min(short_decimal), min(long_decimal) FROM test_aggregation_pushdown",
                "SELECT max(short_decimal), max(long_decimal) FROM test_aggregation_pushdown",
                "SELECT sum(short_decimal), sum(long_decimal) FROM test_aggregation_pushdown",
                "SELECT avg(short_decimal), avg(long_decimal) FROM test_aggregation_pushdown",
        };
        for (String sql : pushedDownOverDecimals) {
            assertThat(query(sql)).isCorrectlyPushedDown();
        }
    }
}

@Test
public void testColumnComment()
{
Expand Down Expand Up @@ -272,6 +302,20 @@ public void testPredicatePushdown()
.isCorrectlyPushedDown();
}

/**
 * Creates a table directly in MySQL and returns a handle that drops it when closed,
 * intended for use in a try-with-resources statement.
 *
 * @param tableName name of the table to create, as it should appear in the DDL
 * @param tableDefinition column definition list appended verbatim after the table name
 * @return a closeable whose {@code close()} issues {@code DROP TABLE} for the table
 * @throws Exception declared for callers; this method itself only propagates what {@code execute} throws
 */
private AutoCloseable withTable(String tableName, String tableDefinition)
        throws Exception
{
    execute(format("CREATE TABLE %s%s", tableName, tableDefinition));
    // A failing DROP propagates as-is: re-wrapping a RuntimeException in another
    // RuntimeException only buries the real error one level deeper in the stack trace.
    return () -> execute(format("DROP TABLE %s", tableName));
}

private void execute(String sql)
{
mysqlServer.execute(sql, mysqlServer.getUsername(), mysqlServer.getPassword());
Expand Down

0 comments on commit 193872a

Please sign in to comment.