Skip to content

Commit

Permalink
Use JDBC parameters in JDBC complex expression pushdown
Browse files Browse the repository at this point in the history
  • Loading branch information
findepi committed Mar 22, 2023
1 parent 757aa68 commit e3bbf6c
Show file tree
Hide file tree
Showing 58 changed files with 617 additions and 397 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import com.google.common.collect.ImmutableSortedSet;
import com.google.common.io.Closer;
import io.airlift.log.Logger;
import io.trino.plugin.jdbc.expression.ParameterizedExpression;
import io.trino.plugin.jdbc.logging.RemoteQueryModifier;
import io.trino.plugin.jdbc.mapping.IdentifierMapping;
import io.trino.spi.TrinoException;
Expand Down Expand Up @@ -421,7 +422,7 @@ public PreparedQuery prepareQuery(
JdbcTableHandle table,
Optional<List<List<JdbcColumnHandle>>> groupingSets,
List<JdbcColumnHandle> columns,
Map<String, String> columnExpressions)
Map<String, ParameterizedExpression> columnExpressions)
{
verify(table.getAuthorization().isEmpty(), "Unexpected authorization is required for table: %s".formatted(table));
try (Connection connection = connectionFactory.openConnection(session)) {
Expand All @@ -446,7 +447,7 @@ protected PreparedQuery prepareQuery(
JdbcTableHandle table,
Optional<List<List<JdbcColumnHandle>>> groupingSets,
List<JdbcColumnHandle> columns,
Map<String, String> columnExpressions,
Map<String, ParameterizedExpression> columnExpressions,
Optional<JdbcSplit> split)
{
return applyQueryTransformations(table, queryBuilder.prepareSelectQuery(
Expand All @@ -461,15 +462,18 @@ protected PreparedQuery prepareQuery(
getAdditionalPredicate(table.getConstraintExpressions(), split.flatMap(JdbcSplit::getAdditionalPredicate))));
}

protected static Optional<String> getAdditionalPredicate(List<String> constraintExpressions, Optional<String> splitPredicate)
protected static Optional<ParameterizedExpression> getAdditionalPredicate(List<ParameterizedExpression> constraintExpressions, Optional<String> splitPredicate)
{
if (constraintExpressions.isEmpty() && splitPredicate.isEmpty()) {
return Optional.empty();
}

return Optional.of(
Stream.concat(constraintExpressions.stream(), splitPredicate.stream())
.collect(joining(") AND (", "(", ")")));
return Optional.of(new ParameterizedExpression(
Stream.concat(constraintExpressions.stream().map(ParameterizedExpression::expression), splitPredicate.stream())
.collect(joining(") AND (", "(", ")")),
constraintExpressions.stream()
.flatMap(expressionRewrite -> expressionRewrite.parameters().stream())
.collect(toImmutableList())));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import io.trino.collect.cache.EvictableCacheBuilder;
import io.trino.plugin.base.session.SessionPropertiesProvider;
import io.trino.plugin.jdbc.IdentityCacheMapping.IdentityCacheKey;
import io.trino.plugin.jdbc.expression.ParameterizedExpression;
import io.trino.spi.TrinoException;
import io.trino.spi.connector.AggregateFunction;
import io.trino.spi.connector.ColumnHandle;
Expand Down Expand Up @@ -224,7 +225,7 @@ public Optional<JdbcExpression> implementAggregation(ConnectorSession session, A
}

@Override
public Optional<String> convertPredicate(ConnectorSession session, ConnectorExpression expression, Map<String, ColumnHandle> assignments)
public Optional<ParameterizedExpression> convertPredicate(ConnectorSession session, ConnectorExpression expression, Map<String, ColumnHandle> assignments)
{
return delegate.convertPredicate(session, expression, assignments);
}
Expand Down Expand Up @@ -255,7 +256,7 @@ public PreparedQuery prepareQuery(
JdbcTableHandle table,
Optional<List<List<JdbcColumnHandle>>> groupingSets,
List<JdbcColumnHandle> columns,
Map<String, String> columnExpressions)
Map<String, ParameterizedExpression> columnExpressions)
{
return delegate.prepareQuery(session, table, groupingSets, columns, columnExpressions);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import com.google.common.collect.ImmutableSet;
import io.airlift.slice.Slice;
import io.trino.plugin.jdbc.PredicatePushdownController.DomainPushdownResult;
import io.trino.plugin.jdbc.expression.ParameterizedExpression;
import io.trino.plugin.jdbc.ptf.Query.QueryFunctionHandle;
import io.trino.spi.TrinoException;
import io.trino.spi.connector.AggregateFunction;
Expand Down Expand Up @@ -157,7 +158,7 @@ public Optional<ConstraintApplicationResult<ConnectorTableHandle>> applyFilter(C

TupleDomain<ColumnHandle> oldDomain = handle.getConstraint();
TupleDomain<ColumnHandle> newDomain = oldDomain.intersect(constraint.getSummary());
List<String> newConstraintExpressions;
List<ParameterizedExpression> newConstraintExpressions;
TupleDomain<ColumnHandle> remainingFilter;
Optional<ConnectorExpression> remainingExpression;
if (newDomain.isNone()) {
Expand Down Expand Up @@ -190,18 +191,18 @@ public Optional<ConstraintApplicationResult<ConnectorTableHandle>> applyFilter(C
remainingFilter = TupleDomain.withColumnDomains(unsupported);

if (isComplexExpressionPushdown(session)) {
List<String> newExpressions = new ArrayList<>();
List<ParameterizedExpression> newExpressions = new ArrayList<>();
List<ConnectorExpression> remainingExpressions = new ArrayList<>();
for (ConnectorExpression expression : extractConjuncts(constraint.getExpression())) {
Optional<String> converted = jdbcClient.convertPredicate(session, expression, constraint.getAssignments());
Optional<ParameterizedExpression> converted = jdbcClient.convertPredicate(session, expression, constraint.getAssignments());
if (converted.isPresent()) {
newExpressions.add(converted.get());
}
else {
remainingExpressions.add(expression);
}
}
newConstraintExpressions = ImmutableSet.<String>builder()
newConstraintExpressions = ImmutableSet.<ParameterizedExpression>builder()
.addAll(handle.getConstraintExpressions())
.addAll(newExpressions)
.build().asList();
Expand Down Expand Up @@ -337,7 +338,7 @@ public Optional<AggregationApplicationResult<ConnectorTableHandle>> applyAggrega
ImmutableList.Builder<JdbcColumnHandle> newColumns = ImmutableList.builder();
ImmutableList.Builder<ConnectorExpression> projections = ImmutableList.builder();
ImmutableList.Builder<Assignment> resultAssignments = ImmutableList.builder();
ImmutableMap.Builder<String, String> expressions = ImmutableMap.builder();
ImmutableMap.Builder<String, ParameterizedExpression> expressions = ImmutableMap.builder();

List<List<JdbcColumnHandle>> groupingSetsAsJdbcColumnHandles = groupingSets.stream()
.map(groupingSet -> groupingSet.stream()
Expand Down Expand Up @@ -374,7 +375,7 @@ public Optional<AggregationApplicationResult<ConnectorTableHandle>> applyAggrega
newColumns.add(newColumn);
projections.add(new Variable(newColumn.getColumnName(), aggregate.getOutputType()));
resultAssignments.add(new Assignment(newColumn.getColumnName(), newColumn, aggregate.getOutputType()));
expressions.put(columnName, expression.get().getExpression());
expressions.put(columnName, new ParameterizedExpression(expression.get().getExpression(), expression.get().getParameters()));
}

List<JdbcColumnHandle> newColumnsList = newColumns.build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import com.google.inject.Inject;
import io.airlift.log.Logger;
import io.airlift.slice.Slice;
import io.trino.plugin.jdbc.expression.ParameterizedExpression;
import io.trino.plugin.jdbc.logging.RemoteQueryModifier;
import io.trino.spi.connector.ColumnHandle;
import io.trino.spi.connector.ConnectorSession;
Expand Down Expand Up @@ -72,9 +73,9 @@ public PreparedQuery prepareSelectQuery(
JdbcRelationHandle baseRelation,
Optional<List<List<JdbcColumnHandle>>> groupingSets,
List<JdbcColumnHandle> columns,
Map<String, String> columnExpressions,
Map<String, ParameterizedExpression> columnExpressions,
TupleDomain<ColumnHandle> tupleDomain,
Optional<String> additionalPredicate)
Optional<ParameterizedExpression> additionalPredicate)
{
if (!tupleDomain.isNone()) {
Map<ColumnHandle, Domain> domains = tupleDomain.getDomains().orElseThrow();
Expand All @@ -88,11 +89,14 @@ public PreparedQuery prepareSelectQuery(
ImmutableList.Builder<String> conjuncts = ImmutableList.builder();
ImmutableList.Builder<QueryParameter> accumulator = ImmutableList.builder();

String sql = "SELECT " + getProjection(client, columns, columnExpressions);
String sql = "SELECT " + getProjection(client, columns, columnExpressions, accumulator::add);
sql += getFrom(client, baseRelation, accumulator::add);

toConjuncts(client, session, connection, tupleDomain, conjuncts, accumulator::add);
additionalPredicate.ifPresent(conjuncts::add);
additionalPredicate.ifPresent(predicate -> {
conjuncts.add(predicate.expression());
accumulator.addAll(predicate.parameters());
});
List<String> clauses = conjuncts.build();
if (!clauses.isEmpty()) {
sql += " WHERE " + Joiner.on(" AND ").join(clauses);
Expand Down Expand Up @@ -150,15 +154,18 @@ public PreparedQuery prepareDeleteQuery(
Connection connection,
JdbcNamedRelationHandle baseRelation,
TupleDomain<ColumnHandle> tupleDomain,
Optional<String> additionalPredicate)
Optional<ParameterizedExpression> additionalPredicate)
{
String sql = "DELETE FROM " + getRelation(client, baseRelation.getRemoteTableName());

ImmutableList.Builder<String> conjuncts = ImmutableList.builder();
ImmutableList.Builder<QueryParameter> accumulator = ImmutableList.builder();

toConjuncts(client, session, connection, tupleDomain, conjuncts, accumulator::add);
additionalPredicate.ifPresent(conjuncts::add);
additionalPredicate.ifPresent(predicate -> {
conjuncts.add(predicate.expression());
accumulator.addAll(predicate.parameters());
});
List<String> clauses = conjuncts.build();
if (!clauses.isEmpty()) {
sql += " WHERE " + Joiner.on(" AND ").join(clauses);
Expand All @@ -182,7 +189,9 @@ public PreparedStatement prepareStatement(
for (int i = 0; i < parameters.size(); i++) {
QueryParameter parameter = parameters.get(i);
int parameterIndex = i + 1;
WriteFunction writeFunction = getWriteFunction(client, session, connection, parameter.getJdbcType(), parameter.getType());
WriteFunction writeFunction = parameter.getJdbcType()
.map(jdbcType -> getWriteFunction(client, session, connection, jdbcType, parameter.getType()))
.orElseGet(() -> getWriteFunction(client, session, parameter.getType()));
Class<?> javaType = writeFunction.getJavaType();
Object value = parameter.getValue()
// The value must be present, since DefaultQueryBuilder never creates null parameters. Values coming from Domain's ValueSet are non-null, and
Expand Down Expand Up @@ -251,21 +260,24 @@ protected String getRelation(JdbcClient client, RemoteTableName remoteTableName)
return client.quoted(remoteTableName);
}

protected String getProjection(JdbcClient client, List<JdbcColumnHandle> columns, Map<String, String> columnExpressions)
protected String getProjection(JdbcClient client, List<JdbcColumnHandle> columns, Map<String, ParameterizedExpression> columnExpressions, Consumer<QueryParameter> accumulator)
{
if (columns.isEmpty()) {
return "1 x";
}
return columns.stream()
.map(jdbcColumnHandle -> {
String columnAlias = client.quoted(jdbcColumnHandle.getColumnName());
String expression = columnExpressions.get(jdbcColumnHandle.getColumnName());
if (expression == null) {
return columnAlias;
}
return format("%s AS %s", expression, columnAlias);
})
.collect(joining(", "));
List<String> projections = new ArrayList<>();
for (JdbcColumnHandle jdbcColumnHandle : columns) {
String columnAlias = client.quoted(jdbcColumnHandle.getColumnName());
ParameterizedExpression expression = columnExpressions.get(jdbcColumnHandle.getColumnName());
if (expression == null) {
projections.add(columnAlias);
}
else {
projections.add(format("%s AS %s", expression.expression(), columnAlias));
expression.parameters().forEach(accumulator);
}
}
return String.join(", ", projections);
}

private String getFrom(JdbcClient client, JdbcRelationHandle baseRelation, Consumer<QueryParameter> accumulator)
Expand Down Expand Up @@ -425,4 +437,9 @@ private static WriteFunction getWriteFunction(JdbcClient client, ConnectorSessio
verify(writeFunction.getJavaType() == type.getJavaType(), "Java type mismatch: %s, %s", writeFunction, type);
return writeFunction;
}

private static WriteFunction getWriteFunction(JdbcClient client, ConnectorSession session, Type type)
{
return client.toWriteMapping(session, type).getWriteFunction();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
*/
package io.trino.plugin.jdbc;

import io.trino.plugin.jdbc.expression.ParameterizedExpression;
import io.trino.spi.connector.AggregateFunction;
import io.trino.spi.connector.ColumnHandle;
import io.trino.spi.connector.ColumnMetadata;
Expand Down Expand Up @@ -127,7 +128,7 @@ public Optional<JdbcExpression> implementAggregation(ConnectorSession session, A
}

@Override
public Optional<String> convertPredicate(ConnectorSession session, ConnectorExpression expression, Map<String, ColumnHandle> assignments)
public Optional<ParameterizedExpression> convertPredicate(ConnectorSession session, ConnectorExpression expression, Map<String, ColumnHandle> assignments)
{
return delegate().convertPredicate(session, expression, assignments);
}
Expand Down Expand Up @@ -158,7 +159,7 @@ public PreparedQuery prepareQuery(
JdbcTableHandle table,
Optional<List<List<JdbcColumnHandle>>> groupingSets,
List<JdbcColumnHandle> columns,
Map<String, String> columnExpressions)
Map<String, ParameterizedExpression> columnExpressions)
{
return delegate().prepareQuery(session, table, groupingSets, columns, columnExpressions);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
*/
package io.trino.plugin.jdbc;

import io.trino.plugin.jdbc.expression.ParameterizedExpression;
import io.trino.spi.TrinoException;
import io.trino.spi.connector.AggregateFunction;
import io.trino.spi.connector.ColumnHandle;
Expand Down Expand Up @@ -78,7 +79,7 @@ default Optional<JdbcExpression> implementAggregation(ConnectorSession session,
return Optional.empty();
}

default Optional<String> convertPredicate(ConnectorSession session, ConnectorExpression expression, Map<String, ColumnHandle> assignments)
default Optional<ParameterizedExpression> convertPredicate(ConnectorSession session, ConnectorExpression expression, Map<String, ColumnHandle> assignments)
{
return Optional.empty();
}
Expand All @@ -99,7 +100,7 @@ PreparedQuery prepareQuery(
JdbcTableHandle table,
Optional<List<List<JdbcColumnHandle>>> groupingSets,
List<JdbcColumnHandle> columns,
Map<String, String> columnExpressions);
Map<String, ParameterizedExpression> columnExpressions);

PreparedStatement buildSql(ConnectorSession session, Connection connection, JdbcSplit split, JdbcTableHandle table, List<JdbcColumnHandle> columns)
throws SQLException;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,23 @@
*/
package io.trino.plugin.jdbc;

import com.google.common.collect.ImmutableList;

import java.util.List;

import static com.google.common.base.MoreObjects.toStringHelper;
import static java.util.Objects.requireNonNull;

public final class JdbcExpression
{
private final String expression;
private final List<QueryParameter> parameters;
private final JdbcTypeHandle jdbcTypeHandle;

public JdbcExpression(String expression, JdbcTypeHandle jdbcTypeHandle)
public JdbcExpression(String expression, List<QueryParameter> parameters, JdbcTypeHandle jdbcTypeHandle)
{
this.expression = requireNonNull(expression, "expression is null");
this.parameters = ImmutableList.copyOf(requireNonNull(parameters, "parameters is null"));
this.jdbcTypeHandle = requireNonNull(jdbcTypeHandle, "jdbcTypeHandle is null");
}

Expand All @@ -32,6 +38,11 @@ public String getExpression()
return expression;
}

public List<QueryParameter> getParameters()
{
return parameters;
}

public JdbcTypeHandle getJdbcTypeHandle()
{
return jdbcTypeHandle;
Expand All @@ -42,6 +53,7 @@ public String toString()
{
return toStringHelper(this)
.add("expression", expression)
.add("parameters", parameters)
.add("jdbcTypeHandle", jdbcTypeHandle)
.toString();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import io.trino.plugin.jdbc.expression.ParameterizedExpression;
import io.trino.spi.connector.ColumnHandle;
import io.trino.spi.connector.ConnectorTableHandle;
import io.trino.spi.connector.SchemaTableName;
Expand All @@ -40,7 +41,7 @@ public final class JdbcTableHandle

private final TupleDomain<ColumnHandle> constraint;
// Additional to constraint
private final List<String> constraintExpressions;
private final List<ParameterizedExpression> constraintExpressions;

// semantically sort order is applied after constraint
private final Optional<List<JdbcSortItem>> sortOrder;
Expand Down Expand Up @@ -78,7 +79,7 @@ public JdbcTableHandle(SchemaTableName schemaTableName, RemoteTableName remoteTa
public JdbcTableHandle(
@JsonProperty("relationHandle") JdbcRelationHandle relationHandle,
@JsonProperty("constraint") TupleDomain<ColumnHandle> constraint,
@JsonProperty("constraintExpressions") List<String> constraintExpressions,
@JsonProperty("constraintExpressions") List<ParameterizedExpression> constraintExpressions,
@JsonProperty("sortOrder") Optional<List<JdbcSortItem>> sortOrder,
@JsonProperty("limit") OptionalLong limit,
@JsonProperty("columns") Optional<List<JdbcColumnHandle>> columns,
Expand Down Expand Up @@ -138,7 +139,7 @@ public TupleDomain<ColumnHandle> getConstraint()
}

@JsonProperty
public List<String> getConstraintExpressions()
public List<ParameterizedExpression> getConstraintExpressions()
{
return constraintExpressions;
}
Expand Down
Loading

0 comments on commit e3bbf6c

Please sign in to comment.