diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/DefaultQueryBuilder.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/DefaultQueryBuilder.java new file mode 100644 index 000000000000..a860d3e6da14 --- /dev/null +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/DefaultQueryBuilder.java @@ -0,0 +1,399 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.jdbc; + +import com.google.common.base.Joiner; +import com.google.common.base.VerifyException; +import com.google.common.collect.ImmutableList; +import io.airlift.log.Logger; +import io.airlift.slice.Slice; +import io.trino.spi.connector.ColumnHandle; +import io.trino.spi.connector.ConnectorSession; +import io.trino.spi.connector.JoinType; +import io.trino.spi.predicate.Domain; +import io.trino.spi.predicate.Range; +import io.trino.spi.predicate.TupleDomain; +import io.trino.spi.predicate.ValueSet; +import io.trino.spi.type.Type; + +import java.sql.Connection; +import java.sql.PreparedStatement; +import java.sql.SQLException; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.function.Consumer; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.base.Verify.verify; +import static com.google.common.collect.Iterables.getOnlyElement; +import static java.lang.String.format; +import static java.util.Collections.nCopies; +import static java.util.Objects.requireNonNull; +import static java.util.stream.Collectors.joining; + +public class DefaultQueryBuilder +{ + private static final Logger log = Logger.get(DefaultQueryBuilder.class); + + // not all databases support booleans, so use 1=1 and 1=0 instead + private static final String ALWAYS_TRUE = "1=1"; + private static final String ALWAYS_FALSE = "1=0"; + + private final JdbcClient client; + + public DefaultQueryBuilder(JdbcClient client) + { + this.client = requireNonNull(client, "client is null"); + } + + public PreparedQuery prepareQuery( + JdbcClient client, + ConnectorSession session, + Connection connection, + JdbcRelationHandle baseRelation, + Optional>> groupingSets, + List columns, + Map columnExpressions, + TupleDomain tupleDomain, + Optional additionalPredicate) + { + if (!tupleDomain.isNone()) { + Map domains = tupleDomain.getDomains().orElseThrow(); + columns.stream() + .filter(domains::containsKey) + .filter(column -> columnExpressions.containsKey(column.getColumnName())) + .findFirst() + .ifPresent(column -> { throw new IllegalArgumentException(format("Column %s has an expression and a constraint attached at the same time", column)); }); + } + + ImmutableList.Builder accumulator = ImmutableList.builder(); + + String sql = "SELECT " + getProjection(columns, columnExpressions); + sql += getFrom(baseRelation, accumulator::add); + + List clauses = toConjuncts(session, connection, tupleDomain, accumulator::add); + if (additionalPredicate.isPresent()) { + clauses = ImmutableList.builder() + .addAll(clauses) + .add(additionalPredicate.get()) + .build(); + } + if (!clauses.isEmpty()) { + sql += " WHERE " + Joiner.on(" AND ").join(clauses); + } + + sql += getGroupBy(groupingSets); + + return new PreparedQuery(sql, accumulator.build()); + } + + public PreparedQuery prepareJoinQuery( + JdbcClient client, + ConnectorSession session, + JoinType joinType, + PreparedQuery leftSource, + PreparedQuery rightSource, + List joinConditions, + Map leftAssignments, + Map rightAssignments) + { + // Verify assignments are present. This is safe assumption as join conditions are not pruned, and simplifies the code here. + verify(!leftAssignments.isEmpty(), "leftAssignments is empty"); + verify(!rightAssignments.isEmpty(), "rightAssignments is empty"); + // Joins wih no conditions are not pushed down, so it is a same assumption and simplifies the code here + verify(!joinConditions.isEmpty(), "joinConditions is empty"); + + String query = format( + "SELECT %s, %s FROM (%s) l %s (%s) r ON %s", + formatAssignments("l", leftAssignments), + formatAssignments("r", rightAssignments), + leftSource.getQuery(), + formatJoinType(joinType), + rightSource.getQuery(), + joinConditions.stream() + .map(condition -> format( + "l.%s %s r.%s", + client.quoted(condition.getLeftColumn().getColumnName()), + condition.getOperator().getValue(), + client.quoted(condition.getRightColumn().getColumnName()))) + .collect(joining(" AND "))); + List parameters = ImmutableList.builder() + .addAll(leftSource.getParameters()) + .addAll(rightSource.getParameters()) + .build(); + return new PreparedQuery(query, parameters); + } + + protected String formatAssignments(String relationAlias, Map assignments) + { + return assignments.entrySet().stream() + .map(entry -> format("%s.%s AS %s", relationAlias, client.quoted(entry.getKey().getColumnName()), client.quoted(entry.getValue()))) + .collect(joining(", ")); + } + + protected static String formatJoinType(JoinType joinType) + { + switch (joinType) { + case INNER: + return "INNER JOIN"; + case LEFT_OUTER: + return "LEFT JOIN"; + case RIGHT_OUTER: + return "RIGHT JOIN"; + case FULL_OUTER: + return "FULL JOIN"; + } + throw new IllegalStateException("Unsupported join type: " + joinType); + } + + public PreparedQuery prepareDelete( + JdbcClient client, + ConnectorSession session, + Connection connection, + JdbcNamedRelationHandle baseRelation, + TupleDomain tupleDomain) + { + String sql = "DELETE FROM " + getRelation(baseRelation.getRemoteTableName()); + + ImmutableList.Builder accumulator = ImmutableList.builder(); + + List clauses = toConjuncts(session, connection, tupleDomain, accumulator::add); + if (!clauses.isEmpty()) { + sql += " WHERE " + Joiner.on(" AND ").join(clauses); + } + return new PreparedQuery(sql, accumulator.build()); + } + + public PreparedStatement prepareStatement( + JdbcClient client, + ConnectorSession session, + Connection connection, + PreparedQuery preparedQuery) + throws SQLException + { + log.debug("Preparing query: %s", preparedQuery.getQuery()); + PreparedStatement statement = client.getPreparedStatement(connection, preparedQuery.getQuery()); + + List parameters = preparedQuery.getParameters(); + for (int i = 0; i < parameters.size(); i++) { + QueryParameter parameter = parameters.get(i); + int parameterIndex = i + 1; + WriteFunction writeFunction = getWriteFunction(session, connection, parameter.getJdbcType(), parameter.getType()); + Class javaType = writeFunction.getJavaType(); + Object value = parameter.getValue() + // The value must be present, since DefaultQueryBuilder never creates null parameters. Values coming from Domain's ValueSet are non-null, and + // nullable domains are handled explicitly, with SQL syntax. + .orElseThrow(() -> new VerifyException("Value is missing")); + if (javaType == boolean.class) { + ((BooleanWriteFunction) writeFunction).set(statement, parameterIndex, (boolean) value); + } + else if (javaType == long.class) { + ((LongWriteFunction) writeFunction).set(statement, parameterIndex, (long) value); + } + else if (javaType == double.class) { + ((DoubleWriteFunction) writeFunction).set(statement, parameterIndex, (double) value); + } + else if (javaType == Slice.class) { + ((SliceWriteFunction) writeFunction).set(statement, parameterIndex, (Slice) value); + } + else { + ((ObjectWriteFunction) writeFunction).set(statement, parameterIndex, value); + } + } + + return statement; + } + + protected String getRelation(RemoteTableName remoteTableName) + { + return client.quoted(remoteTableName); + } + + protected String getProjection(List columns, Map columnExpressions) + { + if (columns.isEmpty()) { + return "1 x"; + } + return columns.stream() + .map(jdbcColumnHandle -> { + String columnAlias = client.quoted(jdbcColumnHandle.getColumnName()); + String expression = columnExpressions.get(jdbcColumnHandle.getColumnName()); + if (expression == null) { + return columnAlias; + } + return format("%s AS %s", expression, columnAlias); + }) + .collect(joining(", ")); + } + + private String getFrom(JdbcRelationHandle baseRelation, Consumer accumulator) + { + if (baseRelation instanceof JdbcNamedRelationHandle) { + return " FROM " + getRelation(((JdbcNamedRelationHandle) baseRelation).getRemoteTableName()); + } + if (baseRelation instanceof JdbcQueryRelationHandle) { + PreparedQuery preparedQuery = ((JdbcQueryRelationHandle) baseRelation).getPreparedQuery(); + preparedQuery.getParameters().forEach(accumulator); + return " FROM (" + preparedQuery.getQuery() + ") o"; + } + throw new IllegalArgumentException("Unsupported relation: " + baseRelation); + } + + private static Domain pushDownDomain(JdbcClient client, ConnectorSession session, Connection connection, JdbcColumnHandle column, Domain domain) + { + return client.toColumnMapping(session, connection, column.getJdbcTypeHandle()) + .orElseThrow(() -> new IllegalStateException(format("Unsupported type %s with handle %s", column.getColumnType(), column.getJdbcTypeHandle()))) + .getPredicatePushdownController().apply(session, domain).getPushedDown(); + } + + private List toConjuncts( + ConnectorSession session, + Connection connection, + TupleDomain tupleDomain, + Consumer accumulator) + { + if (tupleDomain.isNone()) { + return ImmutableList.of(ALWAYS_FALSE); + } + ImmutableList.Builder builder = ImmutableList.builder(); + for (Map.Entry entry : tupleDomain.getDomains().get().entrySet()) { + JdbcColumnHandle column = ((JdbcColumnHandle) entry.getKey()); + Domain domain = pushDownDomain(client, session, connection, column, entry.getValue()); + builder.add(toPredicate(session, connection, column, domain, accumulator)); + } + return builder.build(); + } + + private String toPredicate(ConnectorSession session, Connection connection, JdbcColumnHandle column, Domain domain, Consumer accumulator) + { + if (domain.getValues().isNone()) { + return domain.isNullAllowed() ? client.quoted(column.getColumnName()) + " IS NULL" : ALWAYS_FALSE; + } + + if (domain.getValues().isAll()) { + return domain.isNullAllowed() ? ALWAYS_TRUE : client.quoted(column.getColumnName()) + " IS NOT NULL"; + } + + String predicate = toPredicate(session, connection, column, domain.getValues(), accumulator); + if (!domain.isNullAllowed()) { + return predicate; + } + return format("(%s OR %s IS NULL)", predicate, client.quoted(column.getColumnName())); + } + + private String toPredicate(ConnectorSession session, Connection connection, JdbcColumnHandle column, ValueSet valueSet, Consumer accumulator) + { + checkArgument(!valueSet.isNone(), "none values should be handled earlier"); + + if (!valueSet.isDiscreteSet()) { + ValueSet complement = valueSet.complement(); + if (complement.isDiscreteSet()) { + return format("NOT (%s)", toPredicate(session, connection, column, complement, accumulator)); + } + } + + JdbcTypeHandle jdbcType = column.getJdbcTypeHandle(); + Type type = column.getColumnType(); + WriteFunction writeFunction = getWriteFunction(session, connection, jdbcType, type); + + List disjuncts = new ArrayList<>(); + List singleValues = new ArrayList<>(); + for (Range range : valueSet.getRanges().getOrderedRanges()) { + checkState(!range.isAll()); // Already checked + if (range.isSingleValue()) { + singleValues.add(range.getSingleValue()); + } + else { + List rangeConjuncts = new ArrayList<>(); + if (!range.isLowUnbounded()) { + rangeConjuncts.add(toPredicate(column, jdbcType, type, writeFunction, range.isLowInclusive() ? ">=" : ">", range.getLowBoundedValue(), accumulator)); + } + if (!range.isHighUnbounded()) { + rangeConjuncts.add(toPredicate(column, jdbcType, type, writeFunction, range.isHighInclusive() ? "<=" : "<", range.getHighBoundedValue(), accumulator)); + } + // If rangeConjuncts is null, then the range was ALL, which should already have been checked for + checkState(!rangeConjuncts.isEmpty()); + if (rangeConjuncts.size() == 1) { + disjuncts.add(getOnlyElement(rangeConjuncts)); + } + else { + disjuncts.add("(" + Joiner.on(" AND ").join(rangeConjuncts) + ")"); + } + } + } + + // Add back all of the possible single values either as an equality or an IN predicate + if (singleValues.size() == 1) { + disjuncts.add(toPredicate(column, jdbcType, type, writeFunction, "=", getOnlyElement(singleValues), accumulator)); + } + else if (singleValues.size() > 1) { + for (Object value : singleValues) { + accumulator.accept(new QueryParameter(jdbcType, type, Optional.of(value))); + } + String values = Joiner.on(",").join(nCopies(singleValues.size(), writeFunction.getBindExpression())); + disjuncts.add(client.quoted(column.getColumnName()) + " IN (" + values + ")"); + } + + checkState(!disjuncts.isEmpty()); + if (disjuncts.size() == 1) { + return getOnlyElement(disjuncts); + } + return "(" + Joiner.on(" OR ").join(disjuncts) + ")"; + } + + private String toPredicate(JdbcColumnHandle column, JdbcTypeHandle jdbcType, Type type, WriteFunction writeFunction, String operator, Object value, Consumer accumulator) + { + accumulator.accept(new QueryParameter(jdbcType, type, Optional.of(value))); + return format("%s %s %s", client.quoted(column.getColumnName()), operator, writeFunction.getBindExpression()); + } + + private WriteFunction getWriteFunction(ConnectorSession session, Connection connection, JdbcTypeHandle jdbcType, Type type) + { + WriteFunction writeFunction = client.toColumnMapping(session, connection, jdbcType) + .orElseThrow(() -> new VerifyException(format("Unsupported type %s with handle %s", type, jdbcType))) + .getWriteFunction(); + verify(writeFunction.getJavaType() == type.getJavaType(), "Java type mismatch: %s, %s", writeFunction, type); + return writeFunction; + } + + private String getGroupBy(Optional>> groupingSets) + { + if (groupingSets.isEmpty()) { + return ""; + } + + verify(!groupingSets.get().isEmpty()); + if (groupingSets.get().size() == 1) { + List groupingSet = getOnlyElement(groupingSets.get()); + if (groupingSet.isEmpty()) { + // global aggregation + return ""; + } + return " GROUP BY " + groupingSet.stream() + .map(JdbcColumnHandle::getColumnName) + .map(client::quoted) + .collect(joining(", ")); + } + return " GROUP BY GROUPING SETS " + + groupingSets.get().stream() + .map(groupingSet -> groupingSet.stream() + .map(JdbcColumnHandle::getColumnName) + .map(client::quoted) + .collect(joining(", ", "(", ")"))) + .collect(joining(", ", "(", ")")); + } +}