Skip to content

Commit

Permalink
Move ImplementAvgBigint logic from connectors into base jdbc module
Browse files Browse the repository at this point in the history
  • Loading branch information
grantatspothero authored and hashhar committed Nov 10, 2021
1 parent 1f781db commit 1b5db4a
Show file tree
Hide file tree
Showing 8 changed files with 106 additions and 181 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.plugin.jdbc.expression;

import io.trino.matching.Capture;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.plugin.base.expression.AggregateFunctionRule;
import io.trino.plugin.jdbc.JdbcColumnHandle;
import io.trino.plugin.jdbc.JdbcExpression;
import io.trino.plugin.jdbc.JdbcTypeHandle;
import io.trino.spi.connector.AggregateFunction;
import io.trino.spi.expression.Variable;

import java.sql.Types;
import java.util.Optional;

import static com.google.common.base.Verify.verify;
import static io.trino.matching.Capture.newCapture;
import static io.trino.plugin.base.expression.AggregateFunctionPatterns.basicAggregation;
import static io.trino.plugin.base.expression.AggregateFunctionPatterns.expressionType;
import static io.trino.plugin.base.expression.AggregateFunctionPatterns.functionName;
import static io.trino.plugin.base.expression.AggregateFunctionPatterns.singleInput;
import static io.trino.plugin.base.expression.AggregateFunctionPatterns.variable;
import static io.trino.spi.type.BigintType.BIGINT;
import static io.trino.spi.type.DoubleType.DOUBLE;
import static java.lang.String.format;

/**
* Implements {@code avg(x)} for bigint columns while preserving Trino semantics.
* Trino semantics say the output should be a double but pushing down the aggregation to some databases
* can result in rounding of the output to a bigint.
*/
public abstract class BaseImplementAvgBigint
implements AggregateFunctionRule<JdbcExpression>
{
private final Capture<Variable> input;

public BaseImplementAvgBigint()
{
this.input = newCapture();
}

@Override
public Pattern<AggregateFunction> getPattern()
{
return basicAggregation()
.with(functionName().equalTo("avg"))
.with(singleInput().matching(
variable()
.with(expressionType().matching(type -> type == BIGINT))
.capturedAs(this.input)));
}

@Override
public Optional<JdbcExpression> rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context)
{
Variable input = captures.get(this.input);
JdbcColumnHandle columnHandle = (JdbcColumnHandle) context.getAssignment(input.getName());
verify(aggregateFunction.getOutputType() == DOUBLE);

String columnName = context.getIdentifierQuote().apply(columnHandle.getColumnName());

return Optional.of(new JdbcExpression(
format(getRewriteFormatExpression(), columnName),
new JdbcTypeHandle(Types.DOUBLE, Optional.of("double"), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty())));
}

/**
* Implement this method for each connector supporting avg(bigint) pushdown
* @return A format string expression with a single placeholder for the column name; The string expression pushes down avg to the remote database
*/
protected abstract String getRewriteFormatExpression();
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,52 +13,14 @@
*/
package io.trino.plugin.clickhouse;

import io.trino.matching.Capture;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.plugin.base.expression.AggregateFunctionRule;
import io.trino.plugin.jdbc.JdbcColumnHandle;
import io.trino.plugin.jdbc.JdbcExpression;
import io.trino.plugin.jdbc.JdbcTypeHandle;
import io.trino.spi.connector.AggregateFunction;
import io.trino.spi.expression.Variable;

import java.sql.Types;
import java.util.Optional;

import static com.google.common.base.Verify.verify;
import static io.trino.matching.Capture.newCapture;
import static io.trino.plugin.base.expression.AggregateFunctionPatterns.basicAggregation;
import static io.trino.plugin.base.expression.AggregateFunctionPatterns.expressionType;
import static io.trino.plugin.base.expression.AggregateFunctionPatterns.functionName;
import static io.trino.plugin.base.expression.AggregateFunctionPatterns.singleInput;
import static io.trino.plugin.base.expression.AggregateFunctionPatterns.variable;
import static io.trino.spi.type.BigintType.BIGINT;
import static io.trino.spi.type.DoubleType.DOUBLE;
import static java.lang.String.format;
import io.trino.plugin.jdbc.expression.BaseImplementAvgBigint;

public class ImplementAvgBigint
implements AggregateFunctionRule<JdbcExpression>
extends BaseImplementAvgBigint
{
private static final Capture<Variable> INPUT = newCapture();

@Override
public Pattern<AggregateFunction> getPattern()
{
return basicAggregation()
.with(functionName().equalTo("avg"))
.with(singleInput().matching(variable().with(expressionType().equalTo(BIGINT)).capturedAs(INPUT)));
}

@Override
public Optional<JdbcExpression> rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context)
protected String getRewriteFormatExpression()
{
Variable input = captures.get(INPUT);
JdbcColumnHandle columnHandle = (JdbcColumnHandle) context.getAssignment(input.getName());
verify(aggregateFunction.getOutputType() == DOUBLE);

return Optional.of(new JdbcExpression(
format("avg((%s * 1.0))", context.getIdentifierQuote().apply(columnHandle.getColumnName())),
new JdbcTypeHandle(Types.DOUBLE, Optional.of("double"), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty())));
return "avg((%s * 1.0))";
}
}
5 changes: 0 additions & 5 deletions plugin/trino-mysql/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,6 @@
<artifactId>trino-base-jdbc</artifactId>
</dependency>

<dependency>
<groupId>io.trino</groupId>
<artifactId>trino-matching</artifactId>
</dependency>

<dependency>
<groupId>io.trino</groupId>
<artifactId>trino-plugin-toolkit</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,55 +13,14 @@
*/
package io.trino.plugin.mysql;

import io.trino.matching.Capture;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.plugin.base.expression.AggregateFunctionRule;
import io.trino.plugin.jdbc.JdbcColumnHandle;
import io.trino.plugin.jdbc.JdbcExpression;
import io.trino.plugin.jdbc.JdbcTypeHandle;
import io.trino.spi.connector.AggregateFunction;
import io.trino.spi.expression.Variable;

import java.sql.Types;
import java.util.Optional;

import static com.google.common.base.Verify.verify;
import static io.trino.matching.Capture.newCapture;
import static io.trino.plugin.base.expression.AggregateFunctionPatterns.basicAggregation;
import static io.trino.plugin.base.expression.AggregateFunctionPatterns.expressionType;
import static io.trino.plugin.base.expression.AggregateFunctionPatterns.functionName;
import static io.trino.plugin.base.expression.AggregateFunctionPatterns.singleInput;
import static io.trino.plugin.base.expression.AggregateFunctionPatterns.variable;
import static io.trino.spi.type.BigintType.BIGINT;
import static io.trino.spi.type.DoubleType.DOUBLE;
import static java.lang.String.format;
import io.trino.plugin.jdbc.expression.BaseImplementAvgBigint;

public class ImplementAvgBigint
implements AggregateFunctionRule<JdbcExpression>
extends BaseImplementAvgBigint
{
private static final Capture<Variable> INPUT = newCapture();

@Override
public Pattern<AggregateFunction> getPattern()
{
return basicAggregation()
.with(functionName().equalTo("avg"))
.with(singleInput().matching(
variable()
.with(expressionType().matching(type -> type == BIGINT))
.capturedAs(INPUT)));
}

@Override
public Optional<JdbcExpression> rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context)
protected String getRewriteFormatExpression()
{
Variable input = captures.get(INPUT);
JdbcColumnHandle columnHandle = (JdbcColumnHandle) context.getAssignment(input.getName());
verify(aggregateFunction.getOutputType() == DOUBLE);

return Optional.of(new JdbcExpression(
format("avg((%s * 1.0))", context.getIdentifierQuote().apply(columnHandle.getColumnName())),
new JdbcTypeHandle(Types.DOUBLE, Optional.of("double"), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty())));
return "avg((%s * 1.0))";
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,11 @@ public class OracleClient

private final boolean synonymsEnabled;

/**
* Note the type mappings from trino -> oracle types can cause surprises since they are not invertible
* For example, creating an oracle table in trino with a bigint column will generate an oracle table with a number column
* Then querying the oracle table with the number column will return a decimal (not a bigint)
*/
private static final Map<Type, WriteMapping> WRITE_MAPPINGS = ImmutableMap.<Type, WriteMapping>builder()
.put(BOOLEAN, oracleBooleanWriteMapping())
.put(BIGINT, WriteMapping.longMapping("number(19)", bigintWriteFunction()))
Expand Down
5 changes: 0 additions & 5 deletions plugin/trino-postgresql/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,6 @@
<artifactId>trino-base-jdbc</artifactId>
</dependency>

<dependency>
<groupId>io.trino</groupId>
<artifactId>trino-matching</artifactId>
</dependency>

<dependency>
<groupId>io.trino</groupId>
<artifactId>trino-plugin-toolkit</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,52 +13,14 @@
*/
package io.trino.plugin.postgresql;

import io.trino.matching.Capture;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.plugin.base.expression.AggregateFunctionRule;
import io.trino.plugin.jdbc.JdbcColumnHandle;
import io.trino.plugin.jdbc.JdbcExpression;
import io.trino.plugin.jdbc.JdbcTypeHandle;
import io.trino.spi.connector.AggregateFunction;
import io.trino.spi.expression.Variable;

import java.sql.Types;
import java.util.Optional;

import static com.google.common.base.Verify.verify;
import static io.trino.matching.Capture.newCapture;
import static io.trino.plugin.base.expression.AggregateFunctionPatterns.basicAggregation;
import static io.trino.plugin.base.expression.AggregateFunctionPatterns.expressionType;
import static io.trino.plugin.base.expression.AggregateFunctionPatterns.functionName;
import static io.trino.plugin.base.expression.AggregateFunctionPatterns.singleInput;
import static io.trino.plugin.base.expression.AggregateFunctionPatterns.variable;
import static io.trino.spi.type.BigintType.BIGINT;
import static io.trino.spi.type.DoubleType.DOUBLE;
import static java.lang.String.format;
import io.trino.plugin.jdbc.expression.BaseImplementAvgBigint;

public class ImplementAvgBigint
implements AggregateFunctionRule<JdbcExpression>
extends BaseImplementAvgBigint
{
private static final Capture<Variable> INPUT = newCapture();

@Override
public Pattern<AggregateFunction> getPattern()
{
return basicAggregation()
.with(functionName().equalTo("avg"))
.with(singleInput().matching(variable().with(expressionType().equalTo(BIGINT)).capturedAs(INPUT)));
}

@Override
public Optional<JdbcExpression> rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context)
protected String getRewriteFormatExpression()
{
Variable input = captures.get(INPUT);
JdbcColumnHandle columnHandle = (JdbcColumnHandle) context.getAssignment(input.getName());
verify(aggregateFunction.getOutputType() == DOUBLE);

return Optional.of(new JdbcExpression(
format("avg(CAST(%s AS double precision))", context.getIdentifierQuote().apply(columnHandle.getColumnName())),
new JdbcTypeHandle(Types.DOUBLE, Optional.of("double"), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty())));
return "avg(CAST(%s AS double precision))";
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,52 +13,14 @@
*/
package io.trino.plugin.sqlserver;

import io.trino.matching.Capture;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.plugin.base.expression.AggregateFunctionRule;
import io.trino.plugin.jdbc.JdbcColumnHandle;
import io.trino.plugin.jdbc.JdbcExpression;
import io.trino.plugin.jdbc.JdbcTypeHandle;
import io.trino.spi.connector.AggregateFunction;
import io.trino.spi.expression.Variable;

import java.sql.Types;
import java.util.Optional;

import static com.google.common.base.Verify.verify;
import static io.trino.matching.Capture.newCapture;
import static io.trino.plugin.base.expression.AggregateFunctionPatterns.basicAggregation;
import static io.trino.plugin.base.expression.AggregateFunctionPatterns.expressionType;
import static io.trino.plugin.base.expression.AggregateFunctionPatterns.functionName;
import static io.trino.plugin.base.expression.AggregateFunctionPatterns.singleInput;
import static io.trino.plugin.base.expression.AggregateFunctionPatterns.variable;
import static io.trino.spi.type.BigintType.BIGINT;
import static io.trino.spi.type.DoubleType.DOUBLE;
import static java.lang.String.format;
import io.trino.plugin.jdbc.expression.BaseImplementAvgBigint;

public class ImplementAvgBigint
implements AggregateFunctionRule<JdbcExpression>
extends BaseImplementAvgBigint
{
private static final Capture<Variable> INPUT = newCapture();

@Override
public Pattern<AggregateFunction> getPattern()
{
return basicAggregation()
.with(functionName().equalTo("avg"))
.with(singleInput().matching(variable().with(expressionType().equalTo(BIGINT)).capturedAs(INPUT)));
}

@Override
public Optional<JdbcExpression> rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context)
protected String getRewriteFormatExpression()
{
Variable input = captures.get(INPUT);
JdbcColumnHandle columnHandle = (JdbcColumnHandle) context.getAssignment(input.getName());
verify(aggregateFunction.getOutputType() == DOUBLE);

return Optional.of(new JdbcExpression(
format("avg(CAST(%s AS double precision))", context.getIdentifierQuote().apply(columnHandle.getColumnName())),
new JdbcTypeHandle(Types.DOUBLE, Optional.of("double"), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty())));
return "avg(CAST(%s AS double precision))";
}
}

0 comments on commit 1b5db4a

Please sign in to comment.