-
Notifications
You must be signed in to change notification settings - Fork 3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Create copy of QueryBuilder class as DefaultQueryBuilder
- Loading branch information
Showing
1 changed file
with
399 additions
and
0 deletions.
There are no files selected for viewing
399 changes: 399 additions & 0 deletions
399
plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/DefaultQueryBuilder.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,399 @@ | ||
/* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
package io.trino.plugin.jdbc; | ||
|
||
import com.google.common.base.Joiner; | ||
import com.google.common.base.VerifyException; | ||
import com.google.common.collect.ImmutableList; | ||
import io.airlift.log.Logger; | ||
import io.airlift.slice.Slice; | ||
import io.trino.spi.connector.ColumnHandle; | ||
import io.trino.spi.connector.ConnectorSession; | ||
import io.trino.spi.connector.JoinType; | ||
import io.trino.spi.predicate.Domain; | ||
import io.trino.spi.predicate.Range; | ||
import io.trino.spi.predicate.TupleDomain; | ||
import io.trino.spi.predicate.ValueSet; | ||
import io.trino.spi.type.Type; | ||
|
||
import java.sql.Connection; | ||
import java.sql.PreparedStatement; | ||
import java.sql.SQLException; | ||
import java.util.ArrayList; | ||
import java.util.List; | ||
import java.util.Map; | ||
import java.util.Optional; | ||
import java.util.function.Consumer; | ||
|
||
import static com.google.common.base.Preconditions.checkArgument; | ||
import static com.google.common.base.Preconditions.checkState; | ||
import static com.google.common.base.Verify.verify; | ||
import static com.google.common.collect.Iterables.getOnlyElement; | ||
import static java.lang.String.format; | ||
import static java.util.Collections.nCopies; | ||
import static java.util.Objects.requireNonNull; | ||
import static java.util.stream.Collectors.joining; | ||
|
||
public class DefaultQueryBuilder | ||
{ | ||
private static final Logger log = Logger.get(DefaultQueryBuilder.class); | ||
|
||
// not all databases support booleans, so use 1=1 and 1=0 instead | ||
private static final String ALWAYS_TRUE = "1=1"; | ||
private static final String ALWAYS_FALSE = "1=0"; | ||
|
||
private final JdbcClient client; | ||
|
||
public DefaultQueryBuilder(JdbcClient client) | ||
{ | ||
this.client = requireNonNull(client, "client is null"); | ||
} | ||
|
||
public PreparedQuery prepareQuery( | ||
JdbcClient client, | ||
ConnectorSession session, | ||
Connection connection, | ||
JdbcRelationHandle baseRelation, | ||
Optional<List<List<JdbcColumnHandle>>> groupingSets, | ||
List<JdbcColumnHandle> columns, | ||
Map<String, String> columnExpressions, | ||
TupleDomain<ColumnHandle> tupleDomain, | ||
Optional<String> additionalPredicate) | ||
{ | ||
if (!tupleDomain.isNone()) { | ||
Map<ColumnHandle, Domain> domains = tupleDomain.getDomains().orElseThrow(); | ||
columns.stream() | ||
.filter(domains::containsKey) | ||
.filter(column -> columnExpressions.containsKey(column.getColumnName())) | ||
.findFirst() | ||
.ifPresent(column -> { throw new IllegalArgumentException(format("Column %s has an expression and a constraint attached at the same time", column)); }); | ||
} | ||
|
||
ImmutableList.Builder<QueryParameter> accumulator = ImmutableList.builder(); | ||
|
||
String sql = "SELECT " + getProjection(columns, columnExpressions); | ||
sql += getFrom(baseRelation, accumulator::add); | ||
|
||
List<String> clauses = toConjuncts(session, connection, tupleDomain, accumulator::add); | ||
if (additionalPredicate.isPresent()) { | ||
clauses = ImmutableList.<String>builder() | ||
.addAll(clauses) | ||
.add(additionalPredicate.get()) | ||
.build(); | ||
} | ||
if (!clauses.isEmpty()) { | ||
sql += " WHERE " + Joiner.on(" AND ").join(clauses); | ||
} | ||
|
||
sql += getGroupBy(groupingSets); | ||
|
||
return new PreparedQuery(sql, accumulator.build()); | ||
} | ||
|
||
public PreparedQuery prepareJoinQuery( | ||
JdbcClient client, | ||
ConnectorSession session, | ||
JoinType joinType, | ||
PreparedQuery leftSource, | ||
PreparedQuery rightSource, | ||
List<JdbcJoinCondition> joinConditions, | ||
Map<JdbcColumnHandle, String> leftAssignments, | ||
Map<JdbcColumnHandle, String> rightAssignments) | ||
{ | ||
// Verify assignments are present. This is safe assumption as join conditions are not pruned, and simplifies the code here. | ||
verify(!leftAssignments.isEmpty(), "leftAssignments is empty"); | ||
verify(!rightAssignments.isEmpty(), "rightAssignments is empty"); | ||
// Joins wih no conditions are not pushed down, so it is a same assumption and simplifies the code here | ||
verify(!joinConditions.isEmpty(), "joinConditions is empty"); | ||
|
||
String query = format( | ||
"SELECT %s, %s FROM (%s) l %s (%s) r ON %s", | ||
formatAssignments("l", leftAssignments), | ||
formatAssignments("r", rightAssignments), | ||
leftSource.getQuery(), | ||
formatJoinType(joinType), | ||
rightSource.getQuery(), | ||
joinConditions.stream() | ||
.map(condition -> format( | ||
"l.%s %s r.%s", | ||
client.quoted(condition.getLeftColumn().getColumnName()), | ||
condition.getOperator().getValue(), | ||
client.quoted(condition.getRightColumn().getColumnName()))) | ||
.collect(joining(" AND "))); | ||
List<QueryParameter> parameters = ImmutableList.<QueryParameter>builder() | ||
.addAll(leftSource.getParameters()) | ||
.addAll(rightSource.getParameters()) | ||
.build(); | ||
return new PreparedQuery(query, parameters); | ||
} | ||
|
||
protected String formatAssignments(String relationAlias, Map<JdbcColumnHandle, String> assignments) | ||
{ | ||
return assignments.entrySet().stream() | ||
.map(entry -> format("%s.%s AS %s", relationAlias, client.quoted(entry.getKey().getColumnName()), client.quoted(entry.getValue()))) | ||
.collect(joining(", ")); | ||
} | ||
|
||
protected static String formatJoinType(JoinType joinType) | ||
{ | ||
switch (joinType) { | ||
case INNER: | ||
return "INNER JOIN"; | ||
case LEFT_OUTER: | ||
return "LEFT JOIN"; | ||
case RIGHT_OUTER: | ||
return "RIGHT JOIN"; | ||
case FULL_OUTER: | ||
return "FULL JOIN"; | ||
} | ||
throw new IllegalStateException("Unsupported join type: " + joinType); | ||
} | ||
|
||
public PreparedQuery prepareDelete( | ||
JdbcClient client, | ||
ConnectorSession session, | ||
Connection connection, | ||
JdbcNamedRelationHandle baseRelation, | ||
TupleDomain<ColumnHandle> tupleDomain) | ||
{ | ||
String sql = "DELETE FROM " + getRelation(baseRelation.getRemoteTableName()); | ||
|
||
ImmutableList.Builder<QueryParameter> accumulator = ImmutableList.builder(); | ||
|
||
List<String> clauses = toConjuncts(session, connection, tupleDomain, accumulator::add); | ||
if (!clauses.isEmpty()) { | ||
sql += " WHERE " + Joiner.on(" AND ").join(clauses); | ||
} | ||
return new PreparedQuery(sql, accumulator.build()); | ||
} | ||
|
||
public PreparedStatement prepareStatement( | ||
JdbcClient client, | ||
ConnectorSession session, | ||
Connection connection, | ||
PreparedQuery preparedQuery) | ||
throws SQLException | ||
{ | ||
log.debug("Preparing query: %s", preparedQuery.getQuery()); | ||
PreparedStatement statement = client.getPreparedStatement(connection, preparedQuery.getQuery()); | ||
|
||
List<QueryParameter> parameters = preparedQuery.getParameters(); | ||
for (int i = 0; i < parameters.size(); i++) { | ||
QueryParameter parameter = parameters.get(i); | ||
int parameterIndex = i + 1; | ||
WriteFunction writeFunction = getWriteFunction(session, connection, parameter.getJdbcType(), parameter.getType()); | ||
Class<?> javaType = writeFunction.getJavaType(); | ||
Object value = parameter.getValue() | ||
// The value must be present, since DefaultQueryBuilder never creates null parameters. Values coming from Domain's ValueSet are non-null, and | ||
// nullable domains are handled explicitly, with SQL syntax. | ||
.orElseThrow(() -> new VerifyException("Value is missing")); | ||
if (javaType == boolean.class) { | ||
((BooleanWriteFunction) writeFunction).set(statement, parameterIndex, (boolean) value); | ||
} | ||
else if (javaType == long.class) { | ||
((LongWriteFunction) writeFunction).set(statement, parameterIndex, (long) value); | ||
} | ||
else if (javaType == double.class) { | ||
((DoubleWriteFunction) writeFunction).set(statement, parameterIndex, (double) value); | ||
} | ||
else if (javaType == Slice.class) { | ||
((SliceWriteFunction) writeFunction).set(statement, parameterIndex, (Slice) value); | ||
} | ||
else { | ||
((ObjectWriteFunction) writeFunction).set(statement, parameterIndex, value); | ||
} | ||
} | ||
|
||
return statement; | ||
} | ||
|
||
protected String getRelation(RemoteTableName remoteTableName) | ||
{ | ||
return client.quoted(remoteTableName); | ||
} | ||
|
||
protected String getProjection(List<JdbcColumnHandle> columns, Map<String, String> columnExpressions) | ||
{ | ||
if (columns.isEmpty()) { | ||
return "1 x"; | ||
} | ||
return columns.stream() | ||
.map(jdbcColumnHandle -> { | ||
String columnAlias = client.quoted(jdbcColumnHandle.getColumnName()); | ||
String expression = columnExpressions.get(jdbcColumnHandle.getColumnName()); | ||
if (expression == null) { | ||
return columnAlias; | ||
} | ||
return format("%s AS %s", expression, columnAlias); | ||
}) | ||
.collect(joining(", ")); | ||
} | ||
|
||
private String getFrom(JdbcRelationHandle baseRelation, Consumer<QueryParameter> accumulator) | ||
{ | ||
if (baseRelation instanceof JdbcNamedRelationHandle) { | ||
return " FROM " + getRelation(((JdbcNamedRelationHandle) baseRelation).getRemoteTableName()); | ||
} | ||
if (baseRelation instanceof JdbcQueryRelationHandle) { | ||
PreparedQuery preparedQuery = ((JdbcQueryRelationHandle) baseRelation).getPreparedQuery(); | ||
preparedQuery.getParameters().forEach(accumulator); | ||
return " FROM (" + preparedQuery.getQuery() + ") o"; | ||
} | ||
throw new IllegalArgumentException("Unsupported relation: " + baseRelation); | ||
} | ||
|
||
private static Domain pushDownDomain(JdbcClient client, ConnectorSession session, Connection connection, JdbcColumnHandle column, Domain domain) | ||
{ | ||
return client.toColumnMapping(session, connection, column.getJdbcTypeHandle()) | ||
.orElseThrow(() -> new IllegalStateException(format("Unsupported type %s with handle %s", column.getColumnType(), column.getJdbcTypeHandle()))) | ||
.getPredicatePushdownController().apply(session, domain).getPushedDown(); | ||
} | ||
|
||
private List<String> toConjuncts( | ||
ConnectorSession session, | ||
Connection connection, | ||
TupleDomain<ColumnHandle> tupleDomain, | ||
Consumer<QueryParameter> accumulator) | ||
{ | ||
if (tupleDomain.isNone()) { | ||
return ImmutableList.of(ALWAYS_FALSE); | ||
} | ||
ImmutableList.Builder<String> builder = ImmutableList.builder(); | ||
for (Map.Entry<ColumnHandle, Domain> entry : tupleDomain.getDomains().get().entrySet()) { | ||
JdbcColumnHandle column = ((JdbcColumnHandle) entry.getKey()); | ||
Domain domain = pushDownDomain(client, session, connection, column, entry.getValue()); | ||
builder.add(toPredicate(session, connection, column, domain, accumulator)); | ||
} | ||
return builder.build(); | ||
} | ||
|
||
private String toPredicate(ConnectorSession session, Connection connection, JdbcColumnHandle column, Domain domain, Consumer<QueryParameter> accumulator) | ||
{ | ||
if (domain.getValues().isNone()) { | ||
return domain.isNullAllowed() ? client.quoted(column.getColumnName()) + " IS NULL" : ALWAYS_FALSE; | ||
} | ||
|
||
if (domain.getValues().isAll()) { | ||
return domain.isNullAllowed() ? ALWAYS_TRUE : client.quoted(column.getColumnName()) + " IS NOT NULL"; | ||
} | ||
|
||
String predicate = toPredicate(session, connection, column, domain.getValues(), accumulator); | ||
if (!domain.isNullAllowed()) { | ||
return predicate; | ||
} | ||
return format("(%s OR %s IS NULL)", predicate, client.quoted(column.getColumnName())); | ||
} | ||
|
||
private String toPredicate(ConnectorSession session, Connection connection, JdbcColumnHandle column, ValueSet valueSet, Consumer<QueryParameter> accumulator) | ||
{ | ||
checkArgument(!valueSet.isNone(), "none values should be handled earlier"); | ||
|
||
if (!valueSet.isDiscreteSet()) { | ||
ValueSet complement = valueSet.complement(); | ||
if (complement.isDiscreteSet()) { | ||
return format("NOT (%s)", toPredicate(session, connection, column, complement, accumulator)); | ||
} | ||
} | ||
|
||
JdbcTypeHandle jdbcType = column.getJdbcTypeHandle(); | ||
Type type = column.getColumnType(); | ||
WriteFunction writeFunction = getWriteFunction(session, connection, jdbcType, type); | ||
|
||
List<String> disjuncts = new ArrayList<>(); | ||
List<Object> singleValues = new ArrayList<>(); | ||
for (Range range : valueSet.getRanges().getOrderedRanges()) { | ||
checkState(!range.isAll()); // Already checked | ||
if (range.isSingleValue()) { | ||
singleValues.add(range.getSingleValue()); | ||
} | ||
else { | ||
List<String> rangeConjuncts = new ArrayList<>(); | ||
if (!range.isLowUnbounded()) { | ||
rangeConjuncts.add(toPredicate(column, jdbcType, type, writeFunction, range.isLowInclusive() ? ">=" : ">", range.getLowBoundedValue(), accumulator)); | ||
} | ||
if (!range.isHighUnbounded()) { | ||
rangeConjuncts.add(toPredicate(column, jdbcType, type, writeFunction, range.isHighInclusive() ? "<=" : "<", range.getHighBoundedValue(), accumulator)); | ||
} | ||
// If rangeConjuncts is null, then the range was ALL, which should already have been checked for | ||
checkState(!rangeConjuncts.isEmpty()); | ||
if (rangeConjuncts.size() == 1) { | ||
disjuncts.add(getOnlyElement(rangeConjuncts)); | ||
} | ||
else { | ||
disjuncts.add("(" + Joiner.on(" AND ").join(rangeConjuncts) + ")"); | ||
} | ||
} | ||
} | ||
|
||
// Add back all of the possible single values either as an equality or an IN predicate | ||
if (singleValues.size() == 1) { | ||
disjuncts.add(toPredicate(column, jdbcType, type, writeFunction, "=", getOnlyElement(singleValues), accumulator)); | ||
} | ||
else if (singleValues.size() > 1) { | ||
for (Object value : singleValues) { | ||
accumulator.accept(new QueryParameter(jdbcType, type, Optional.of(value))); | ||
} | ||
String values = Joiner.on(",").join(nCopies(singleValues.size(), writeFunction.getBindExpression())); | ||
disjuncts.add(client.quoted(column.getColumnName()) + " IN (" + values + ")"); | ||
} | ||
|
||
checkState(!disjuncts.isEmpty()); | ||
if (disjuncts.size() == 1) { | ||
return getOnlyElement(disjuncts); | ||
} | ||
return "(" + Joiner.on(" OR ").join(disjuncts) + ")"; | ||
} | ||
|
||
private String toPredicate(JdbcColumnHandle column, JdbcTypeHandle jdbcType, Type type, WriteFunction writeFunction, String operator, Object value, Consumer<QueryParameter> accumulator) | ||
{ | ||
accumulator.accept(new QueryParameter(jdbcType, type, Optional.of(value))); | ||
return format("%s %s %s", client.quoted(column.getColumnName()), operator, writeFunction.getBindExpression()); | ||
} | ||
|
||
private WriteFunction getWriteFunction(ConnectorSession session, Connection connection, JdbcTypeHandle jdbcType, Type type) | ||
{ | ||
WriteFunction writeFunction = client.toColumnMapping(session, connection, jdbcType) | ||
.orElseThrow(() -> new VerifyException(format("Unsupported type %s with handle %s", type, jdbcType))) | ||
.getWriteFunction(); | ||
verify(writeFunction.getJavaType() == type.getJavaType(), "Java type mismatch: %s, %s", writeFunction, type); | ||
return writeFunction; | ||
} | ||
|
||
private String getGroupBy(Optional<List<List<JdbcColumnHandle>>> groupingSets) | ||
{ | ||
if (groupingSets.isEmpty()) { | ||
return ""; | ||
} | ||
|
||
verify(!groupingSets.get().isEmpty()); | ||
if (groupingSets.get().size() == 1) { | ||
List<JdbcColumnHandle> groupingSet = getOnlyElement(groupingSets.get()); | ||
if (groupingSet.isEmpty()) { | ||
// global aggregation | ||
return ""; | ||
} | ||
return " GROUP BY " + groupingSet.stream() | ||
.map(JdbcColumnHandle::getColumnName) | ||
.map(client::quoted) | ||
.collect(joining(", ")); | ||
} | ||
return " GROUP BY GROUPING SETS " + | ||
groupingSets.get().stream() | ||
.map(groupingSet -> groupingSet.stream() | ||
.map(JdbcColumnHandle::getColumnName) | ||
.map(client::quoted) | ||
.collect(joining(", ", "(", ")"))) | ||
.collect(joining(", ", "(", ")")); | ||
} | ||
} |