diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/BaseImplementAvgBigint.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/BaseImplementAvgBigint.java new file mode 100644 index 000000000000..0acbf203e7b5 --- /dev/null +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/BaseImplementAvgBigint.java @@ -0,0 +1,85 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.jdbc.expression; + +import io.trino.matching.Capture; +import io.trino.matching.Captures; +import io.trino.matching.Pattern; +import io.trino.plugin.base.expression.AggregateFunctionRule; +import io.trino.plugin.jdbc.JdbcColumnHandle; +import io.trino.plugin.jdbc.JdbcExpression; +import io.trino.plugin.jdbc.JdbcTypeHandle; +import io.trino.spi.connector.AggregateFunction; +import io.trino.spi.expression.Variable; + +import java.sql.Types; +import java.util.Optional; + +import static com.google.common.base.Verify.verify; +import static io.trino.matching.Capture.newCapture; +import static io.trino.plugin.base.expression.AggregateFunctionPatterns.basicAggregation; +import static io.trino.plugin.base.expression.AggregateFunctionPatterns.expressionType; +import static io.trino.plugin.base.expression.AggregateFunctionPatterns.functionName; +import static io.trino.plugin.base.expression.AggregateFunctionPatterns.singleInput; +import static io.trino.plugin.base.expression.AggregateFunctionPatterns.variable; +import static io.trino.spi.type.BigintType.BIGINT; +import static io.trino.spi.type.DoubleType.DOUBLE; +import static java.lang.String.format; + +/** + * Implements {@code avg(x)} for bigint columns while preserving Trino semantics. + * Trino semantics say the output should be a double but pushing down the aggregation to some databases + * can result in rounding of the output to a bigint. + */ +public abstract class BaseImplementAvgBigint + implements AggregateFunctionRule +{ + private final Capture input; + + public BaseImplementAvgBigint() + { + this.input = newCapture(); + } + + @Override + public Pattern getPattern() + { + return basicAggregation() + .with(functionName().equalTo("avg")) + .with(singleInput().matching( + variable() + .with(expressionType().matching(type -> type == BIGINT)) + .capturedAs(this.input))); + } + + @Override + public Optional rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context) + { + Variable input = captures.get(this.input); + JdbcColumnHandle columnHandle = (JdbcColumnHandle) context.getAssignment(input.getName()); + verify(aggregateFunction.getOutputType() == DOUBLE); + + String columnName = context.getIdentifierQuote().apply(columnHandle.getColumnName()); + + return Optional.of(new JdbcExpression( + format(getRewriteFormatExpression(), columnName), + new JdbcTypeHandle(Types.DOUBLE, Optional.of("double"), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()))); + } + + /** + * Implement this method for each connector supporting avg(bigint) pushdown + * @return A format string expression with a single placeholder for the column name; The string expression pushes down avg to the remote database + */ + protected abstract String getRewriteFormatExpression(); +} diff --git a/plugin/trino-clickhouse/src/main/java/io/trino/plugin/clickhouse/ImplementAvgBigint.java b/plugin/trino-clickhouse/src/main/java/io/trino/plugin/clickhouse/ImplementAvgBigint.java index be2ffd465221..38be958a782a 100644 --- a/plugin/trino-clickhouse/src/main/java/io/trino/plugin/clickhouse/ImplementAvgBigint.java +++ b/plugin/trino-clickhouse/src/main/java/io/trino/plugin/clickhouse/ImplementAvgBigint.java @@ -13,52 +13,14 @@ */ package io.trino.plugin.clickhouse; -import io.trino.matching.Capture; -import io.trino.matching.Captures; -import io.trino.matching.Pattern; -import io.trino.plugin.base.expression.AggregateFunctionRule; -import io.trino.plugin.jdbc.JdbcColumnHandle; -import io.trino.plugin.jdbc.JdbcExpression; -import io.trino.plugin.jdbc.JdbcTypeHandle; -import io.trino.spi.connector.AggregateFunction; -import io.trino.spi.expression.Variable; - -import java.sql.Types; -import java.util.Optional; - -import static com.google.common.base.Verify.verify; -import static io.trino.matching.Capture.newCapture; -import static io.trino.plugin.base.expression.AggregateFunctionPatterns.basicAggregation; -import static io.trino.plugin.base.expression.AggregateFunctionPatterns.expressionType; -import static io.trino.plugin.base.expression.AggregateFunctionPatterns.functionName; -import static io.trino.plugin.base.expression.AggregateFunctionPatterns.singleInput; -import static io.trino.plugin.base.expression.AggregateFunctionPatterns.variable; -import static io.trino.spi.type.BigintType.BIGINT; -import static io.trino.spi.type.DoubleType.DOUBLE; -import static java.lang.String.format; +import io.trino.plugin.jdbc.expression.BaseImplementAvgBigint; public class ImplementAvgBigint - implements AggregateFunctionRule + extends BaseImplementAvgBigint { - private static final Capture INPUT = newCapture(); - - @Override - public Pattern getPattern() - { - return basicAggregation() - .with(functionName().equalTo("avg")) - .with(singleInput().matching(variable().with(expressionType().equalTo(BIGINT)).capturedAs(INPUT))); - } - @Override - public Optional rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context) + protected String getRewriteFormatExpression() { - Variable input = captures.get(INPUT); - JdbcColumnHandle columnHandle = (JdbcColumnHandle) context.getAssignment(input.getName()); - verify(aggregateFunction.getOutputType() == DOUBLE); - - return Optional.of(new JdbcExpression( - format("avg((%s * 1.0))", context.getIdentifierQuote().apply(columnHandle.getColumnName())), - new JdbcTypeHandle(Types.DOUBLE, Optional.of("double"), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()))); + return "avg((%s * 1.0))"; } } diff --git a/plugin/trino-mysql/pom.xml b/plugin/trino-mysql/pom.xml index 77f59ee2089e..5129792c8fe1 100644 --- a/plugin/trino-mysql/pom.xml +++ b/plugin/trino-mysql/pom.xml @@ -23,11 +23,6 @@ trino-base-jdbc - - io.trino - trino-matching - - io.trino trino-plugin-toolkit diff --git a/plugin/trino-mysql/src/main/java/io/trino/plugin/mysql/ImplementAvgBigint.java b/plugin/trino-mysql/src/main/java/io/trino/plugin/mysql/ImplementAvgBigint.java index c6eba8716efe..9c5d45a35085 100644 --- a/plugin/trino-mysql/src/main/java/io/trino/plugin/mysql/ImplementAvgBigint.java +++ b/plugin/trino-mysql/src/main/java/io/trino/plugin/mysql/ImplementAvgBigint.java @@ -13,55 +13,14 @@ */ package io.trino.plugin.mysql; -import io.trino.matching.Capture; -import io.trino.matching.Captures; -import io.trino.matching.Pattern; -import io.trino.plugin.base.expression.AggregateFunctionRule; -import io.trino.plugin.jdbc.JdbcColumnHandle; -import io.trino.plugin.jdbc.JdbcExpression; -import io.trino.plugin.jdbc.JdbcTypeHandle; -import io.trino.spi.connector.AggregateFunction; -import io.trino.spi.expression.Variable; - -import java.sql.Types; -import java.util.Optional; - -import static com.google.common.base.Verify.verify; -import static io.trino.matching.Capture.newCapture; -import static io.trino.plugin.base.expression.AggregateFunctionPatterns.basicAggregation; -import static io.trino.plugin.base.expression.AggregateFunctionPatterns.expressionType; -import static io.trino.plugin.base.expression.AggregateFunctionPatterns.functionName; -import static io.trino.plugin.base.expression.AggregateFunctionPatterns.singleInput; -import static io.trino.plugin.base.expression.AggregateFunctionPatterns.variable; -import static io.trino.spi.type.BigintType.BIGINT; -import static io.trino.spi.type.DoubleType.DOUBLE; -import static java.lang.String.format; +import io.trino.plugin.jdbc.expression.BaseImplementAvgBigint; public class ImplementAvgBigint - implements AggregateFunctionRule + extends BaseImplementAvgBigint { - private static final Capture INPUT = newCapture(); - - @Override - public Pattern getPattern() - { - return basicAggregation() - .with(functionName().equalTo("avg")) - .with(singleInput().matching( - variable() - .with(expressionType().matching(type -> type == BIGINT)) - .capturedAs(INPUT))); - } - @Override - public Optional rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context) + protected String getRewriteFormatExpression() { - Variable input = captures.get(INPUT); - JdbcColumnHandle columnHandle = (JdbcColumnHandle) context.getAssignment(input.getName()); - verify(aggregateFunction.getOutputType() == DOUBLE); - - return Optional.of(new JdbcExpression( - format("avg((%s * 1.0))", context.getIdentifierQuote().apply(columnHandle.getColumnName())), - new JdbcTypeHandle(Types.DOUBLE, Optional.of("double"), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()))); + return "avg((%s * 1.0))"; } } diff --git a/plugin/trino-oracle/src/main/java/io/trino/plugin/oracle/OracleClient.java b/plugin/trino-oracle/src/main/java/io/trino/plugin/oracle/OracleClient.java index c1216fbcb99d..1b10cba5ae2c 100644 --- a/plugin/trino-oracle/src/main/java/io/trino/plugin/oracle/OracleClient.java +++ b/plugin/trino-oracle/src/main/java/io/trino/plugin/oracle/OracleClient.java @@ -142,6 +142,11 @@ public class OracleClient private final boolean synonymsEnabled; + /** + * Note the type mappings from trino -> oracle types can cause surprises since they are not invertible + * For example, creating an oracle table in trino with a bigint column will generate an oracle table with a number column + * Then querying the oracle table with the number column will return a decimal (not a bigint) + */ private static final Map WRITE_MAPPINGS = ImmutableMap.builder() .put(BOOLEAN, oracleBooleanWriteMapping()) .put(BIGINT, WriteMapping.longMapping("number(19)", bigintWriteFunction())) diff --git a/plugin/trino-postgresql/pom.xml b/plugin/trino-postgresql/pom.xml index a4d847a159bc..91c0ba63913f 100644 --- a/plugin/trino-postgresql/pom.xml +++ b/plugin/trino-postgresql/pom.xml @@ -23,11 +23,6 @@ trino-base-jdbc - - io.trino - trino-matching - - io.trino trino-plugin-toolkit diff --git a/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/ImplementAvgBigint.java b/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/ImplementAvgBigint.java index f17a8c10adca..a6d7a03949ba 100644 --- a/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/ImplementAvgBigint.java +++ b/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/ImplementAvgBigint.java @@ -13,52 +13,14 @@ */ package io.trino.plugin.postgresql; -import io.trino.matching.Capture; -import io.trino.matching.Captures; -import io.trino.matching.Pattern; -import io.trino.plugin.base.expression.AggregateFunctionRule; -import io.trino.plugin.jdbc.JdbcColumnHandle; -import io.trino.plugin.jdbc.JdbcExpression; -import io.trino.plugin.jdbc.JdbcTypeHandle; -import io.trino.spi.connector.AggregateFunction; -import io.trino.spi.expression.Variable; - -import java.sql.Types; -import java.util.Optional; - -import static com.google.common.base.Verify.verify; -import static io.trino.matching.Capture.newCapture; -import static io.trino.plugin.base.expression.AggregateFunctionPatterns.basicAggregation; -import static io.trino.plugin.base.expression.AggregateFunctionPatterns.expressionType; -import static io.trino.plugin.base.expression.AggregateFunctionPatterns.functionName; -import static io.trino.plugin.base.expression.AggregateFunctionPatterns.singleInput; -import static io.trino.plugin.base.expression.AggregateFunctionPatterns.variable; -import static io.trino.spi.type.BigintType.BIGINT; -import static io.trino.spi.type.DoubleType.DOUBLE; -import static java.lang.String.format; +import io.trino.plugin.jdbc.expression.BaseImplementAvgBigint; public class ImplementAvgBigint - implements AggregateFunctionRule + extends BaseImplementAvgBigint { - private static final Capture INPUT = newCapture(); - - @Override - public Pattern getPattern() - { - return basicAggregation() - .with(functionName().equalTo("avg")) - .with(singleInput().matching(variable().with(expressionType().equalTo(BIGINT)).capturedAs(INPUT))); - } - @Override - public Optional rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context) + protected String getRewriteFormatExpression() { - Variable input = captures.get(INPUT); - JdbcColumnHandle columnHandle = (JdbcColumnHandle) context.getAssignment(input.getName()); - verify(aggregateFunction.getOutputType() == DOUBLE); - - return Optional.of(new JdbcExpression( - format("avg(CAST(%s AS double precision))", context.getIdentifierQuote().apply(columnHandle.getColumnName())), - new JdbcTypeHandle(Types.DOUBLE, Optional.of("double"), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()))); + return "avg(CAST(%s AS double precision))"; } } diff --git a/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/ImplementAvgBigint.java b/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/ImplementAvgBigint.java index ad80b4493fcc..4c4b8e8e9dbd 100644 --- a/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/ImplementAvgBigint.java +++ b/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/ImplementAvgBigint.java @@ -13,52 +13,14 @@ */ package io.trino.plugin.sqlserver; -import io.trino.matching.Capture; -import io.trino.matching.Captures; -import io.trino.matching.Pattern; -import io.trino.plugin.base.expression.AggregateFunctionRule; -import io.trino.plugin.jdbc.JdbcColumnHandle; -import io.trino.plugin.jdbc.JdbcExpression; -import io.trino.plugin.jdbc.JdbcTypeHandle; -import io.trino.spi.connector.AggregateFunction; -import io.trino.spi.expression.Variable; - -import java.sql.Types; -import java.util.Optional; - -import static com.google.common.base.Verify.verify; -import static io.trino.matching.Capture.newCapture; -import static io.trino.plugin.base.expression.AggregateFunctionPatterns.basicAggregation; -import static io.trino.plugin.base.expression.AggregateFunctionPatterns.expressionType; -import static io.trino.plugin.base.expression.AggregateFunctionPatterns.functionName; -import static io.trino.plugin.base.expression.AggregateFunctionPatterns.singleInput; -import static io.trino.plugin.base.expression.AggregateFunctionPatterns.variable; -import static io.trino.spi.type.BigintType.BIGINT; -import static io.trino.spi.type.DoubleType.DOUBLE; -import static java.lang.String.format; +import io.trino.plugin.jdbc.expression.BaseImplementAvgBigint; public class ImplementAvgBigint - implements AggregateFunctionRule + extends BaseImplementAvgBigint { - private static final Capture INPUT = newCapture(); - - @Override - public Pattern getPattern() - { - return basicAggregation() - .with(functionName().equalTo("avg")) - .with(singleInput().matching(variable().with(expressionType().equalTo(BIGINT)).capturedAs(INPUT))); - } - @Override - public Optional rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context) + protected String getRewriteFormatExpression() { - Variable input = captures.get(INPUT); - JdbcColumnHandle columnHandle = (JdbcColumnHandle) context.getAssignment(input.getName()); - verify(aggregateFunction.getOutputType() == DOUBLE); - - return Optional.of(new JdbcExpression( - format("avg(CAST(%s AS double precision))", context.getIdentifierQuote().apply(columnHandle.getColumnName())), - new JdbcTypeHandle(Types.DOUBLE, Optional.of("double"), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()))); + return "avg(CAST(%s AS double precision))"; } }