Skip to content

Commit

Permalink
Enable push down updates into connector for all jdbc connectors
Browse files Browse the repository at this point in the history
  • Loading branch information
vlad-lyutenko committed Sep 8, 2023
1 parent 462e026 commit 57cd7af
Show file tree
Hide file tree
Showing 34 changed files with 643 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,8 @@ public JdbcTableHandle getTableHandle(ConnectorSession session, PreparedQuery pr
// The query is opaque, so we don't know referenced tables
Optional.empty(),
0,
Optional.empty());
Optional.empty(),
ImmutableList.of());
}
catch (SQLException e) {
throw new TrinoException(JDBC_ERROR, "Failed to get table handle for prepared query. " + firstNonNull(e.getMessage(), e), e);
Expand Down Expand Up @@ -1328,6 +1329,7 @@ public OptionalLong delete(ConnectorSession session, JdbcTableHandle handle)
checkArgument(handle.isNamedRelation(), "Unable to delete from synthetic table: %s", handle);
checkArgument(handle.getLimit().isEmpty(), "Unable to delete when limit is set: %s", handle);
checkArgument(handle.getSortOrder().isEmpty(), "Unable to delete when sort order is set: %s", handle);
checkArgument(handle.getUpdateAssignments().isEmpty(), "Unable to delete when update assignments are set: %s", handle);
verify(handle.getAuthorization().isEmpty(), "Unexpected authorization is required for table: %s".formatted(handle));
try (Connection connection = connectionFactory.openConnection(session)) {
verify(connection.getAutoCommit());
Expand All @@ -1347,6 +1349,33 @@ public OptionalLong delete(ConnectorSession session, JdbcTableHandle handle)
}
}

@Override
public OptionalLong update(ConnectorSession session, JdbcTableHandle handle)
{
checkArgument(handle.isNamedRelation(), "Unable to update from synthetic table: %s", handle);
checkArgument(handle.getLimit().isEmpty(), "Unable to update when limit is set: %s", handle);
checkArgument(handle.getSortOrder().isEmpty(), "Unable to update when sort order is set: %s", handle);
checkArgument(!handle.getUpdateAssignments().isEmpty(), "Unable to update when update assignments are not set: %s", handle);
verify(handle.getAuthorization().isEmpty(), "Unexpected authorization is required for table: %s".formatted(handle));
try (Connection connection = connectionFactory.openConnection(session)) {
verify(connection.getAutoCommit());
PreparedQuery preparedQuery = queryBuilder.prepareUpdateQuery(
this,
session,
connection,
handle.getRequiredNamedRelation(),
handle.getConstraint(),
getAdditionalPredicate(handle.getConstraintExpressions(), Optional.empty()),
handle.getUpdateAssignments());
try (PreparedStatement preparedStatement = queryBuilder.prepareStatement(this, session, connection, preparedQuery, Optional.empty())) {
return OptionalLong.of(preparedStatement.executeUpdate());
}
}
catch (SQLException e) {
throw new TrinoException(JDBC_ERROR, e);
}
}

@Override
public void truncateTable(ConnectorSession session, JdbcTableHandle handle)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -602,6 +602,14 @@ public OptionalLong delete(ConnectorSession session, JdbcTableHandle handle)
return deletedRowsCount;
}

@Override
public OptionalLong update(ConnectorSession session, JdbcTableHandle handle)
{
OptionalLong updatedRowsCount = delegate.update(session, handle);
onDataChanged(handle.getRequiredNamedRelation().getSchemaTableName());
return updatedRowsCount;
}

@Override
public void truncateTable(ConnectorSession session, JdbcTableHandle handle)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
import io.trino.spi.connector.LimitApplicationResult;
import io.trino.spi.connector.ProjectionApplicationResult;
import io.trino.spi.connector.RetryMode;
import io.trino.spi.connector.RowChangeParadigm;
import io.trino.spi.connector.SchemaTableName;
import io.trino.spi.connector.SchemaTablePrefix;
import io.trino.spi.connector.SortItem;
Expand Down Expand Up @@ -96,6 +97,7 @@
import static io.trino.plugin.jdbc.JdbcWriteSessionProperties.isNonTransactionalInsert;
import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED;
import static io.trino.spi.connector.RetryMode.NO_RETRIES;
import static io.trino.spi.connector.RowChangeParadigm.CHANGE_ONLY_UPDATED_COLUMNS;
import static io.trino.spi.type.BigintType.BIGINT;
import static java.lang.Math.max;
import static java.util.Objects.requireNonNull;
Expand Down Expand Up @@ -241,7 +243,8 @@ public Optional<ConstraintApplicationResult<ConnectorTableHandle>> applyFilter(C
handle.getColumns(),
handle.getOtherReferencedTables(),
handle.getNextSyntheticColumnId(),
handle.getAuthorization());
handle.getAuthorization(),
handle.getUpdateAssignments());

return Optional.of(
remainingExpression.isPresent()
Expand All @@ -263,7 +266,8 @@ private JdbcTableHandle flushAttributesAsQuery(ConnectorSession session, JdbcTab
Optional.of(columns),
handle.getAllReferencedTables(),
handle.getNextSyntheticColumnId(),
handle.getAuthorization());
handle.getAuthorization(),
handle.getUpdateAssignments());
}

@Override
Expand Down Expand Up @@ -309,7 +313,8 @@ public Optional<ProjectionApplicationResult<ConnectorTableHandle>> applyProjecti
Optional.of(newColumns),
handle.getOtherReferencedTables(),
handle.getNextSyntheticColumnId(),
handle.getAuthorization()),
handle.getAuthorization(),
handle.getUpdateAssignments()),
projections,
assignments.entrySet().stream()
.map(assignment -> new Assignment(
Expand Down Expand Up @@ -419,7 +424,8 @@ public Optional<AggregationApplicationResult<ConnectorTableHandle>> applyAggrega
Optional.of(newColumnsList),
handle.getAllReferencedTables(),
nextSyntheticColumnId,
handle.getAuthorization());
handle.getAuthorization(),
handle.getUpdateAssignments());

return Optional.of(new AggregationApplicationResult<>(handle, projections.build(), resultAssignments.build(), ImmutableMap.of(), precalculateStatisticsForPushdown));
}
Expand Down Expand Up @@ -514,7 +520,8 @@ public Optional<JoinApplicationResult<ConnectorTableHandle>> applyJoin(
.addAll(rightReferencedTables)
.build())),
nextSyntheticColumnId,
leftHandle.getAuthorization()),
leftHandle.getAuthorization(),
leftHandle.getUpdateAssignments()),
ImmutableMap.copyOf(newLeftColumns),
ImmutableMap.copyOf(newRightColumns),
precalculateStatisticsForPushdown));
Expand Down Expand Up @@ -576,7 +583,8 @@ public Optional<LimitApplicationResult<ConnectorTableHandle>> applyLimit(Connect
handle.getColumns(),
handle.getOtherReferencedTables(),
handle.getNextSyntheticColumnId(),
handle.getAuthorization());
handle.getAuthorization(),
handle.getUpdateAssignments());

return Optional.of(new LimitApplicationResult<>(handle, jdbcClient.isLimitGuaranteed(session), precalculateStatisticsForPushdown));
}
Expand Down Expand Up @@ -630,7 +638,8 @@ public Optional<TopNApplicationResult<ConnectorTableHandle>> applyTopN(
handle.getColumns(),
handle.getOtherReferencedTables(),
handle.getNextSyntheticColumnId(),
handle.getAuthorization());
handle.getAuthorization(),
handle.getUpdateAssignments());

return Optional.of(new TopNApplicationResult<>(sortedTableHandle, jdbcClient.isTopNGuaranteed(session), precalculateStatisticsForPushdown));
}
Expand Down Expand Up @@ -899,6 +908,24 @@ public OptionalLong executeDelete(ConnectorSession session, ConnectorTableHandle
return jdbcClient.delete(session, (JdbcTableHandle) handle);
}

@Override
public Optional<ConnectorTableHandle> applyUpdate(ConnectorSession session, ConnectorTableHandle handle, Map<ColumnHandle, Constant> assignments)
{
return Optional.of(((JdbcTableHandle) handle).withAssignments(assignments));
}

@Override
public RowChangeParadigm getRowChangeParadigm(ConnectorSession session, ConnectorTableHandle tableHandle)
{
return CHANGE_ONLY_UPDATED_COLUMNS;
}

@Override
public OptionalLong executeUpdate(ConnectorSession session, ConnectorTableHandle handle)
{
return jdbcClient.update(session, (JdbcTableHandle) handle);
}

@Override
public void truncateTable(ConnectorSession session, ConnectorTableHandle tableHandle)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,57 @@ public PreparedQuery prepareDeleteQuery(
return new PreparedQuery(sql, accumulator.build());
}

@Override
public PreparedQuery prepareUpdateQuery(
JdbcClient client,
ConnectorSession session,
Connection connection,
JdbcNamedRelationHandle baseRelation,
TupleDomain<ColumnHandle> tupleDomain,
Optional<ParameterizedExpression> additionalPredicate,
List<JdbcAssignmentItem> assignments)
{
ImmutableList.Builder<QueryParameter> accumulator = ImmutableList.builder();

String sql = "UPDATE " + getRelation(client, baseRelation.getRemoteTableName()) + " SET ";

assignments.forEach(entry -> {
JdbcColumnHandle columnHandle = entry.column();
accumulator.add(
new QueryParameter(
columnHandle.getJdbcTypeHandle(),
columnHandle.getColumnType(),
entry.queryParameter().getValue()));
});

sql += assignments.stream()
.map(JdbcAssignmentItem::column)
.map(columnHandle -> {
String bindExpression = getWriteFunction(
client,
session,
connection,
columnHandle.getJdbcTypeHandle(),
columnHandle.getColumnType())
.getBindExpression();
return client.quoted(columnHandle.getColumnName()) + " = " + bindExpression;
})
.collect(joining(", "));

ImmutableList.Builder<String> conjuncts = ImmutableList.builder();

toConjuncts(client, session, connection, tupleDomain, conjuncts, accumulator::add);
additionalPredicate.ifPresent(predicate -> {
conjuncts.add(predicate.expression());
accumulator.addAll(predicate.parameters());
});
List<String> clauses = conjuncts.build();
if (!clauses.isEmpty()) {
sql += " WHERE " + Joiner.on(" AND ").join(clauses);
}
return new PreparedQuery(sql, accumulator.build());
}

@Override
public PreparedStatement prepareStatement(
JdbcClient client,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,12 @@ public OptionalLong delete(ConnectorSession session, JdbcTableHandle handle)
return delegate().delete(session, handle);
}

@Override
public OptionalLong update(ConnectorSession session, JdbcTableHandle handle)
{
return delegate().update(session, handle);
}

@Override
public void truncateTable(ConnectorSession session, JdbcTableHandle handle)
{
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.plugin.jdbc;

import com.fasterxml.jackson.annotation.JsonProperty;

import static java.util.Objects.requireNonNull;

public record JdbcAssignmentItem(@JsonProperty("column") JdbcColumnHandle column, @JsonProperty("queryParameter") QueryParameter queryParameter)
{
public JdbcAssignmentItem
{
requireNonNull(column, "column is null");
requireNonNull(queryParameter, "queryParameter is null");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -236,5 +236,7 @@ default Optional<TableScanRedirectApplicationResult> getTableScanRedirection(Con

void truncateTable(ConnectorSession session, JdbcTableHandle handle);

OptionalLong update(ConnectorSession session, JdbcTableHandle handle);

OptionalInt getMaxWriteParallelism(ConnectorSession session);
}
Loading

0 comments on commit 57cd7af

Please sign in to comment.