From 75729d572a0eac4f75ae8f785ca2103f5777ffd5 Mon Sep 17 00:00:00 2001 From: chenjian2664 Date: Thu, 30 May 2024 19:47:14 +0800 Subject: [PATCH 1/2] Support UPDATE and MERGE for Mongo connector --- docs/src/main/sphinx/connector/mongodb.md | 2 + .../trino/plugin/mongodb/MongoMergeSink.java | 309 ++++++++++++++++++ .../plugin/mongodb/MongoMergeTableHandle.java | 77 +++++ .../trino/plugin/mongodb/MongoMetadata.java | 53 ++- .../trino/plugin/mongodb/MongoPageSink.java | 10 +- .../plugin/mongodb/MongoPageSinkProvider.java | 17 + .../mongodb/BaseMongoConnectorSmokeTest.java | 4 +- .../mongodb/BaseMongoFailureRecoveryTest.java | 29 +- .../mongodb/TestMongoConnectorTest.java | 68 ++-- 9 files changed, 508 insertions(+), 61 deletions(-) create mode 100644 plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoMergeSink.java create mode 100644 plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoMergeTableHandle.java diff --git a/docs/src/main/sphinx/connector/mongodb.md b/docs/src/main/sphinx/connector/mongodb.md index ebdd4e219991..83306597e228 100644 --- a/docs/src/main/sphinx/connector/mongodb.md +++ b/docs/src/main/sphinx/connector/mongodb.md @@ -476,6 +476,8 @@ statements, the connector supports the following features: - {doc}`/sql/insert` - {doc}`/sql/delete` +- {doc}`/sql/update` +- {doc}`/sql/merge` - {doc}`/sql/create-table` - {doc}`/sql/create-table-as` - {doc}`/sql/drop-table` diff --git a/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoMergeSink.java b/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoMergeSink.java new file mode 100644 index 000000000000..e96998aa5e1c --- /dev/null +++ b/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoMergeSink.java @@ -0,0 +1,309 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.mongodb; + +import com.google.common.base.Suppliers; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.mongodb.client.MongoCollection; +import com.mongodb.client.model.BulkWriteOptions; +import com.mongodb.client.model.UpdateOneModel; +import com.mongodb.client.model.WriteModel; +import io.airlift.slice.Slice; +import io.trino.spi.Page; +import io.trino.spi.block.Block; +import io.trino.spi.connector.ColumnHandle; +import io.trino.spi.connector.ConnectorMergeSink; +import io.trino.spi.connector.ConnectorPageSink; +import io.trino.spi.connector.ConnectorPageSinkId; +import org.bson.Document; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.CompletableFuture; +import java.util.function.Supplier; +import java.util.stream.IntStream; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static com.mongodb.client.model.Filters.in; +import static io.trino.plugin.mongodb.MongoMetadata.MERGE_ROW_ID_BASE_NAME; +import static io.trino.spi.type.IntegerType.INTEGER; +import static io.trino.spi.type.TinyintType.TINYINT; +import static java.util.Objects.requireNonNull; +import static org.weakref.jmx.$internal.guava.collect.ImmutableList.toImmutableList; + +public class MongoMergeSink + implements ConnectorMergeSink +{ + private static final String SET = "$set"; + + private final int columnCount; + + private final ConnectorPageSink insertSink; + private final ConnectorPageSink deleteSink; + + private final Map> updateSinkSuppliers; + private final Map updateCaseChannels; + + public MongoMergeSink( + MongoSession mongoSession, + RemoteTableName remoteTableName, + List columns, + Map> updateCaseColumns, + MongoColumnHandle mergeColumnHandle, + String implicitPrefix, + Optional pageSinkIdColumnName, + ConnectorPageSinkId pageSinkId) + { + requireNonNull(mongoSession, "mongoSession is null"); + requireNonNull(remoteTableName, "remoteTableName is null"); + requireNonNull(columns, "columns is null"); + requireNonNull(updateCaseColumns, "updateCaseColumns is null"); + requireNonNull(pageSinkId, "pageSinkId is null"); + requireNonNull(mergeColumnHandle, "mergeColumnHandle is null"); + + this.columnCount = columns.size(); + + this.insertSink = new MongoPageSink(mongoSession, remoteTableName, columns, implicitPrefix, pageSinkIdColumnName, pageSinkId); + this.deleteSink = new MongoDeleteSink(mongoSession, remoteTableName, ImmutableList.of(mergeColumnHandle), implicitPrefix, pageSinkIdColumnName, pageSinkId); + + ImmutableMap.Builder> updateSinksBuilder = ImmutableMap.builder(); + ImmutableMap.Builder updateCaseChannelsBuilder = ImmutableMap.builder(); + for (Map.Entry> entry : updateCaseColumns.entrySet()) { + int caseNumber = entry.getKey(); + List updateColumns = entry.getValue().stream() + .map(MongoColumnHandle.class::cast) + .collect(toImmutableList()); + Set columnChannels = updateColumns.stream() + .map(columns::indexOf) + .collect(toImmutableSet()); + Supplier updateSupplier = Suppliers.memoize(() -> createUpdateSink( + mongoSession, + remoteTableName, + updateColumns, + mergeColumnHandle, + implicitPrefix, + pageSinkIdColumnName, + pageSinkId)); + updateSinksBuilder.put(caseNumber, updateSupplier); + updateCaseChannelsBuilder.put(caseNumber, columnChannels.stream().mapToInt(Integer::intValue).sorted().toArray()); + } + this.updateSinkSuppliers = updateSinksBuilder.buildOrThrow(); + this.updateCaseChannels = updateCaseChannelsBuilder.buildOrThrow(); + } + + @Override + public void storeMergedRows(Page page) + { + checkArgument(page.getChannelCount() == 3 + columnCount, "The page size should be 3 + columnCount (%s), but is %s", columnCount, page.getChannelCount()); + int positionCount = page.getPositionCount(); + Block operationBlock = page.getBlock(columnCount); + Block rowId = page.getBlock(columnCount + 2); + + int[] dataChannel = IntStream.range(0, columnCount + 1).toArray(); + dataChannel[columnCount] = columnCount + 2; + Page dataPage = page.getColumns(dataChannel); + + int[] insertPositions = new int[positionCount]; + int insertPositionCount = 0; + int[] deletePositions = new int[positionCount]; + int deletePositionCount = 0; + + Block updateCaseBlock = page.getBlock(columnCount + 1); + Map updatePositions = new HashMap<>(); + Map updatePositionCounts = new HashMap<>(); + + for (int position = 0; position < positionCount; position++) { + int operation = TINYINT.getByte(operationBlock, position); + switch (operation) { + case INSERT_OPERATION_NUMBER -> { + insertPositions[insertPositionCount] = position; + insertPositionCount++; + } + case DELETE_OPERATION_NUMBER -> { + deletePositions[deletePositionCount] = position; + deletePositionCount++; + } + case UPDATE_OPERATION_NUMBER -> { + int caseNumber = INTEGER.getInt(updateCaseBlock, position); + int updatePositionCount = updatePositionCounts.getOrDefault(caseNumber, 0); + updatePositions.computeIfAbsent(caseNumber, _ -> new int[positionCount])[updatePositionCount] = position; + updatePositionCounts.put(caseNumber, updatePositionCount + 1); + } + default -> throw new IllegalStateException("Unexpected value: " + operation); + } + } + + if (deletePositionCount > 0) { + Block positions = rowId.getPositions(deletePositions, 0, deletePositionCount); + deleteSink.appendPage(new Page(deletePositionCount, positions)); + } + + for (Map.Entry entry : updatePositionCounts.entrySet()) { + int caseNumber = entry.getKey(); + int updatePositionCount = entry.getValue(); + if (updatePositionCount > 0) { + checkArgument(updatePositions.containsKey(caseNumber), "Unexpected case number %s", caseNumber); + int[] positions = updatePositions.get(caseNumber); + int[] updateAssignmentChannels = updateCaseChannels.get(caseNumber); + Block[] updateBlocks = new Block[updateAssignmentChannels.length + 1]; + for (int channel = 0; channel < updateAssignmentChannels.length; channel++) { + updateBlocks[channel] = dataPage.getBlock(updateAssignmentChannels[channel]).getPositions(positions, 0, updatePositionCount); + } + updateBlocks[updateAssignmentChannels.length] = rowId.getPositions(positions, 0, updatePositionCount); + + updateSinkSuppliers.get(caseNumber).get().appendPage(new Page(updatePositionCount, updateBlocks)); + } + } + + if (insertPositionCount > 0) { + // Insert page should not include _id column by default, unless the insert columns include it explicitly + Page insertPage = dataPage.getColumns(Arrays.copyOf(dataChannel, columnCount)); + insertSink.appendPage(insertPage.getPositions(insertPositions, 0, insertPositionCount)); + } + } + + @Override + public CompletableFuture> finish() + { + CompletableFuture> finish = insertSink.finish(); + deleteSink.finish(); + updateSinkSuppliers.values().stream().map(Supplier::get).forEach(ConnectorPageSink::finish); + + return finish; + } + + @Override + public void abort() + { + insertSink.abort(); + deleteSink.abort(); + updateSinkSuppliers.values().stream().map(Supplier::get).forEach(ConnectorPageSink::abort); + } + + private static ConnectorPageSink createUpdateSink( + MongoSession mongoSession, + RemoteTableName remoteTableName, + Collection columns, + MongoColumnHandle mergeColumnHandle, + String implicitPrefix, + Optional pageSinkIdColumnName, + ConnectorPageSinkId pageSinkId) + { + // Update should always include id column explicitly + List updateColumns = ImmutableList.builderWithExpectedSize(columns.size() + 1) + .addAll(columns) + .add(mergeColumnHandle) + .build(); + return new MongoUpdateSink(mongoSession, remoteTableName, updateColumns, implicitPrefix, pageSinkIdColumnName, pageSinkId); + } + + private static class MongoUpdateSink + implements ConnectorPageSink + { + private final MongoPageSink delegate; + private final MongoSession mongoSession; + private final RemoteTableName remoteTableName; + + public MongoUpdateSink( + MongoSession mongoSession, + RemoteTableName remoteTableName, + List columns, + String implicitPrefix, + Optional pageSinkIdColumnName, + ConnectorPageSinkId pageSinkId) + { + this.delegate = new MongoPageSink(mongoSession, remoteTableName, columns, implicitPrefix, pageSinkIdColumnName, pageSinkId); + this.mongoSession = requireNonNull(mongoSession, "mongoSession is null"); + this.remoteTableName = requireNonNull(remoteTableName, "remoteTableName is null"); + } + + @Override + public CompletableFuture appendPage(Page page) + { + MongoCollection collection = mongoSession.getCollection(remoteTableName); + List> bulkWrites = new ArrayList<>(); + for (Document document : delegate.buildBatchDocumentsFromPage(page)) { + Document filter = new Document(MERGE_ROW_ID_BASE_NAME, document.get(MERGE_ROW_ID_BASE_NAME)); + bulkWrites.add(new UpdateOneModel<>(filter, new Document(SET, document))); + } + collection.bulkWrite(bulkWrites, new BulkWriteOptions().ordered(false)); + return NOT_BLOCKED; + } + + @Override + public CompletableFuture> finish() + { + return delegate.finish(); + } + + @Override + public void abort() + { + delegate.abort(); + } + } + + private static class MongoDeleteSink + implements ConnectorPageSink + { + private final MongoPageSink delegate; + private final MongoSession mongoSession; + private final RemoteTableName remoteTableName; + + public MongoDeleteSink( + MongoSession mongoSession, + RemoteTableName remoteTableName, + List columns, + String implicitPrefix, + Optional pageSinkIdColumnName, + ConnectorPageSinkId pageSinkId) + { + this.delegate = new MongoPageSink(mongoSession, remoteTableName, columns, implicitPrefix, pageSinkIdColumnName, pageSinkId); + this.mongoSession = requireNonNull(mongoSession, "mongoSession is null"); + this.remoteTableName = requireNonNull(remoteTableName, "remoteTableName is null"); + } + + @Override + public CompletableFuture appendPage(Page page) + { + MongoCollection collection = mongoSession.getCollection(remoteTableName); + ImmutableList.Builder idsToDeleteBuilder = ImmutableList.builder(); + for (Document document : delegate.buildBatchDocumentsFromPage(page)) { + idsToDeleteBuilder.add(document.get(MERGE_ROW_ID_BASE_NAME)); + } + collection.deleteMany(in(MERGE_ROW_ID_BASE_NAME, idsToDeleteBuilder.build())); + return NOT_BLOCKED; + } + + @Override + public CompletableFuture> finish() + { + return delegate.finish(); + } + + @Override + public void abort() + { + delegate.abort(); + } + } +} diff --git a/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoMergeTableHandle.java b/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoMergeTableHandle.java new file mode 100644 index 000000000000..a45029db15fd --- /dev/null +++ b/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoMergeTableHandle.java @@ -0,0 +1,77 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.mongodb; + +import com.fasterxml.jackson.annotation.JsonIgnore; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import io.trino.spi.connector.ColumnHandle; +import io.trino.spi.connector.ConnectorMergeTableHandle; +import io.trino.spi.connector.ConnectorTableHandle; +import io.trino.spi.connector.SchemaTableName; +import io.trino.spi.predicate.TupleDomain; + +import java.util.Collection; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.OptionalInt; + +import static com.google.common.base.Preconditions.checkArgument; +import static java.util.Objects.requireNonNull; + +public record MongoMergeTableHandle( + RemoteTableName remoteTableName, + List columns, + Map> updateCaseColumns, + MongoColumnHandle mergeRowIdColumn, + Optional filter, + TupleDomain constraint, + Optional temporaryTableName, + Optional pageSinkIdColumnName) + implements ConnectorMergeTableHandle +{ + public MongoMergeTableHandle + { + requireNonNull(remoteTableName, "remoteTableName is null"); + columns = ImmutableList.copyOf(requireNonNull(columns, "columns is null")); + updateCaseColumns = ImmutableMap.copyOf(requireNonNull(updateCaseColumns, "updateCaseColumns is null")); + requireNonNull(filter, "filter is null"); + requireNonNull(mergeRowIdColumn, "mergeRowIdColumn is null"); + requireNonNull(constraint, "constraint is null"); + requireNonNull(temporaryTableName, "temporaryTableName is null"); + requireNonNull(pageSinkIdColumnName, "pageSinkIdColumnName is null"); + checkArgument(temporaryTableName.isPresent() == pageSinkIdColumnName.isPresent(), + "temporaryTableName.isPresent is not equal to pageSinkIdColumnName.isPresent"); + } + + @JsonIgnore + public Optional getTemporaryRemoteTableName() + { + return temporaryTableName.map(tableName -> new RemoteTableName(remoteTableName.databaseName(), tableName)); + } + + @Override + public ConnectorTableHandle getTableHandle() + { + return new MongoTableHandle( + new SchemaTableName(remoteTableName.databaseName(), remoteTableName.collectionName()), + remoteTableName, + filter, + constraint, + ImmutableSet.of(), + OptionalInt.empty()); + } +} diff --git a/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoMetadata.java b/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoMetadata.java index e023e61e7211..cf7a72e782da 100644 --- a/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoMetadata.java +++ b/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoMetadata.java @@ -29,6 +29,7 @@ import io.trino.spi.connector.ColumnHandle; import io.trino.spi.connector.ColumnMetadata; import io.trino.spi.connector.ConnectorInsertTableHandle; +import io.trino.spi.connector.ConnectorMergeTableHandle; import io.trino.spi.connector.ConnectorMetadata; import io.trino.spi.connector.ConnectorOutputMetadata; import io.trino.spi.connector.ConnectorOutputTableHandle; @@ -46,6 +47,7 @@ import io.trino.spi.connector.ProjectionApplicationResult; import io.trino.spi.connector.RelationColumnsMetadata; import io.trino.spi.connector.RetryMode; +import io.trino.spi.connector.RowChangeParadigm; import io.trino.spi.connector.SaveMode; import io.trino.spi.connector.SchemaTableName; import io.trino.spi.connector.SchemaTablePrefix; @@ -107,6 +109,7 @@ import static io.trino.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED; import static io.trino.spi.connector.RelationColumnsMetadata.forTable; +import static io.trino.spi.connector.RowChangeParadigm.CHANGE_ONLY_UPDATED_COLUMNS; import static io.trino.spi.connector.SaveMode.REPLACE; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.DoubleType.DOUBLE; @@ -126,6 +129,8 @@ public class MongoMetadata implements ConnectorMetadata { + public static final String MERGE_ROW_ID_BASE_NAME = "_id"; + private static final Logger log = Logger.get(MongoMetadata.class); private static final Type TRINO_PAGE_SINK_ID_COLUMN_TYPE = BigintType.BIGINT; @@ -544,10 +549,56 @@ private void finishInsert( } } + @Override + public RowChangeParadigm getRowChangeParadigm(ConnectorSession session, ConnectorTableHandle tableHandle) + { + return CHANGE_ONLY_UPDATED_COLUMNS; + } + + @Override + public ConnectorMergeTableHandle beginMerge(ConnectorSession session, ConnectorTableHandle tableHandle, Map> updateCaseColumns, RetryMode retryMode) + { + if (retryMode != RetryMode.NO_RETRIES) { + throw new TrinoException(NOT_SUPPORTED, "This connector does not support MERGE with retries"); + } + + MongoTableHandle table = (MongoTableHandle) tableHandle; + MongoInsertTableHandle insertTableHandle = (MongoInsertTableHandle) beginInsert(session, tableHandle, ImmutableList.of(), retryMode); + return new MongoMergeTableHandle( + insertTableHandle.remoteTableName(), + insertTableHandle.columns(), + updateCaseColumns, + (MongoColumnHandle) getMergeRowIdColumnHandle(session, tableHandle), + table.filter(), + table.constraint(), + insertTableHandle.temporaryTableName(), + insertTableHandle.pageSinkIdColumnName()); + } + + @Override + public void finishMerge( + ConnectorSession session, + ConnectorMergeTableHandle mergeTableHandle, + List sourceTableHandles, + Collection fragments, + Collection computedStatistics) + { + MongoMergeTableHandle tableHandle = (MongoMergeTableHandle) mergeTableHandle; + MongoInsertTableHandle insertTableHandle = new MongoInsertTableHandle( + tableHandle.remoteTableName(), + ImmutableList.copyOf(tableHandle.columns()), + tableHandle.temporaryTableName(), + tableHandle.pageSinkIdColumnName()); + finishInsert(session, insertTableHandle, ImmutableList.of(), fragments, computedStatistics); + } + @Override public ColumnHandle getMergeRowIdColumnHandle(ConnectorSession session, ConnectorTableHandle tableHandle) { - return new MongoColumnHandle("$merge_row_id", ImmutableList.of(), BIGINT, true, false, Optional.empty()); + Map columnHandles = getColumnHandles(session, tableHandle); + checkState(columnHandles.containsKey(MERGE_ROW_ID_BASE_NAME), "id column %s not exists", MERGE_ROW_ID_BASE_NAME); + Type idColumnType = ((MongoColumnHandle) columnHandles.get(MERGE_ROW_ID_BASE_NAME)).type(); + return new MongoColumnHandle(MERGE_ROW_ID_BASE_NAME, ImmutableList.of(), idColumnType, true, false, Optional.empty()); } @Override diff --git a/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoPageSink.java b/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoPageSink.java index 06496aac7b09..7d21e426c2b8 100644 --- a/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoPageSink.java +++ b/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoPageSink.java @@ -110,6 +110,12 @@ public MongoPageSink( public CompletableFuture appendPage(Page page) { MongoCollection collection = mongoSession.getCollection(remoteTableName); + collection.insertMany(buildBatchDocumentsFromPage(page), new InsertManyOptions().ordered(true)); + return NOT_BLOCKED; + } + + protected List buildBatchDocumentsFromPage(Page page) + { List batch = new ArrayList<>(page.getPositionCount()); for (int position = 0; position < page.getPositionCount(); position++) { @@ -122,9 +128,7 @@ public CompletableFuture appendPage(Page page) } batch.add(doc); } - - collection.insertMany(batch, new InsertManyOptions().ordered(true)); - return NOT_BLOCKED; + return batch; } private Object getObjectValue(Type type, Block block, int position) diff --git a/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoPageSinkProvider.java b/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoPageSinkProvider.java index 7ebf8667b4de..ee35b8e292e4 100644 --- a/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoPageSinkProvider.java +++ b/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoPageSinkProvider.java @@ -15,6 +15,8 @@ import com.google.inject.Inject; import io.trino.spi.connector.ConnectorInsertTableHandle; +import io.trino.spi.connector.ConnectorMergeSink; +import io.trino.spi.connector.ConnectorMergeTableHandle; import io.trino.spi.connector.ConnectorOutputTableHandle; import io.trino.spi.connector.ConnectorPageSink; import io.trino.spi.connector.ConnectorPageSinkId; @@ -50,4 +52,19 @@ public ConnectorPageSink createPageSink(ConnectorTransactionHandle transactionHa MongoInsertTableHandle handle = (MongoInsertTableHandle) insertTableHandle; return new MongoPageSink(mongoSession, handle.getTemporaryRemoteTableName().orElseGet(handle::remoteTableName), handle.columns(), implicitPrefix, handle.pageSinkIdColumnName(), pageSinkId); } + + @Override + public ConnectorMergeSink createMergeSink(ConnectorTransactionHandle transactionHandle, ConnectorSession session, ConnectorMergeTableHandle mergeHandle, ConnectorPageSinkId pageSinkId) + { + MongoMergeTableHandle handle = (MongoMergeTableHandle) mergeHandle; + return new MongoMergeSink( + mongoSession, + handle.getTemporaryRemoteTableName().orElseGet(handle::remoteTableName), + handle.columns(), + handle.updateCaseColumns(), + handle.mergeRowIdColumn(), + implicitPrefix, + handle.pageSinkIdColumnName(), + pageSinkId); + } } diff --git a/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/BaseMongoConnectorSmokeTest.java b/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/BaseMongoConnectorSmokeTest.java index f39404a42ea0..27ad160ab8fc 100644 --- a/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/BaseMongoConnectorSmokeTest.java +++ b/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/BaseMongoConnectorSmokeTest.java @@ -30,11 +30,9 @@ protected boolean hasBehavior(TestingConnectorBehavior connectorBehavior) return switch (connectorBehavior) { case SUPPORTS_CREATE_MATERIALIZED_VIEW, SUPPORTS_CREATE_VIEW, - SUPPORTS_MERGE, SUPPORTS_NOT_NULL_CONSTRAINT, SUPPORTS_RENAME_SCHEMA, - SUPPORTS_TRUNCATE, - SUPPORTS_UPDATE -> false; + SUPPORTS_TRUNCATE -> false; default -> super.hasBehavior(connectorBehavior); }; } diff --git a/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/BaseMongoFailureRecoveryTest.java b/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/BaseMongoFailureRecoveryTest.java index 8b1d9ca99d1b..327365ab3ffd 100644 --- a/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/BaseMongoFailureRecoveryTest.java +++ b/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/BaseMongoFailureRecoveryTest.java @@ -24,6 +24,7 @@ import java.util.List; import java.util.Map; +import java.util.Optional; import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.junit.jupiter.api.Assumptions.abort; @@ -69,24 +70,32 @@ protected void testAnalyzeTable() @Override protected void testDelete() { - assertThatThrownBy(super::testDeleteWithSubquery).hasMessageContaining("This connector does not support modifying table rows"); - abort("skipped"); + // This simple delete on Mongo ends up as a very simple, single-fragment, coordinator-only plan, + // which has no ability to recover from errors. This test simply verifies that's still the case. + Optional setupQuery = Optional.of("CREATE TABLE AS SELECT * FROM orders"); + String testQuery = "DELETE FROM
WHERE orderkey = 1"; + Optional cleanupQuery = Optional.of("DROP TABLE
"); + + assertThatQuery(testQuery) + .withSetupQuery(setupQuery) + .withCleanupQuery(cleanupQuery) + .isCoordinatorOnly(); } @Test @Override protected void testDeleteWithSubquery() { - assertThatThrownBy(super::testDeleteWithSubquery).hasMessageContaining("This connector does not support modifying table rows"); - abort("skipped"); + // TODO: solve https://github.com/trinodb/trino/issues/22256 + assertThatThrownBy(super::testDeleteWithSubquery).hasMessageContaining("This connector does not support MERGE with retries"); } @Test @Override protected void testMerge() { - assertThatThrownBy(super::testMerge).hasMessageContaining("This connector does not support modifying table rows"); - abort("skipped"); + // TODO: solve https://github.com/trinodb/trino/issues/22256 + assertThatThrownBy(super::testMerge).hasMessageContaining("This connector does not support MERGE with retries"); } @Test @@ -102,16 +111,16 @@ protected void testRefreshMaterializedView() @Override protected void testUpdate() { - assertThatThrownBy(super::testUpdate).hasMessageContaining("This connector does not support modifying table rows"); - abort("skipped"); + // TODO: solve https://github.com/trinodb/trino/issues/22256 + assertThatThrownBy(super::testUpdate).hasMessageContaining("This connector does not support MERGE with retries"); } @Test @Override protected void testUpdateWithSubquery() { - assertThatThrownBy(super::testUpdateWithSubquery).hasMessageContaining("This connector does not support modifying table rows"); - abort("skipped"); + // TODO: solve https://github.com/trinodb/trino/issues/22256 + assertThatThrownBy(super::testUpdateWithSubquery).hasMessageContaining("This connector does not support MERGE with retries"); } @Override diff --git a/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/TestMongoConnectorTest.java b/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/TestMongoConnectorTest.java index fd85c4abb8a6..a90a4da19965 100644 --- a/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/TestMongoConnectorTest.java +++ b/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/TestMongoConnectorTest.java @@ -53,7 +53,6 @@ import static com.mongodb.client.model.CollationStrength.PRIMARY; import static io.trino.plugin.mongodb.MongoQueryRunner.createMongoClient; import static io.trino.plugin.mongodb.TypeUtils.isPushdownSupportedType; -import static io.trino.spi.connector.ConnectorMetadata.MODIFYING_ROWS_MESSAGE; import static io.trino.testing.TestingNames.randomNameSuffix; import static java.lang.String.format; import static java.nio.charset.StandardCharsets.UTF_8; @@ -103,13 +102,11 @@ protected boolean hasBehavior(TestingConnectorBehavior connectorBehavior) SUPPORTS_CREATE_MATERIALIZED_VIEW, SUPPORTS_CREATE_VIEW, SUPPORTS_DROP_FIELD, - SUPPORTS_MERGE, SUPPORTS_NOT_NULL_CONSTRAINT, SUPPORTS_RENAME_FIELD, SUPPORTS_RENAME_SCHEMA, SUPPORTS_SET_FIELD_TYPE, - SUPPORTS_TRUNCATE, - SUPPORTS_UPDATE -> false; + SUPPORTS_TRUNCATE -> false; default -> super.hasBehavior(connectorBehavior); }; } @@ -120,6 +117,29 @@ protected TestTable createTableWithDefaultColumns() return abort("MongoDB connector does not support column default values"); } + @Test + public void testMergeWithCustomizeTypeIdColumn() + { + String targetTable = "merge_with_customize_type_id_column_target_" + randomNameSuffix(); + String sourceTable = "merget_with_customize_type_id_column_source_" + randomNameSuffix(); + createTableForWrites("CREATE TABLE %s (_id INT, customer VARCHAR, purchases INT, address VARCHAR)", targetTable, Optional.empty()); + + assertUpdate("INSERT INTO %s (_id, customer, purchases, address) VALUES (1, 'Aaron', 5, 'Antioch'), (2, 'Bill', 7, 'Buena'), (3, 'Carol', 3, 'Cambridge'), (4, 'Dave', 11, 'Devon')".formatted(targetTable), 4); + + createTableForWrites("CREATE TABLE %s (_id INT, customer VARCHAR, purchases INT, address VARCHAR)", sourceTable, Optional.empty()); + + assertUpdate("INSERT INTO %s (_id, customer, purchases, address) VALUES (1, 'Aaron', 6, 'Arches'), (3, 'Ed', 7, 'Etherville'), (4, 'Carol', 9, 'Centreville'), (7, 'Dave', 11, 'Darbyshire')".formatted(sourceTable), 4); + + assertUpdate("MERGE INTO %s t USING %s s ON (t._id = s._id)".formatted(targetTable, sourceTable) + + " WHEN MATCHED AND s.address = 'Centreville' THEN DELETE" + + " WHEN MATCHED THEN UPDATE SET purchases = s.purchases + t.purchases, address = s.address" + + " WHEN NOT MATCHED THEN INSERT (_id, customer, purchases, address) VALUES(s._id, s.customer, s.purchases, s.address)", 4); + assertQuery("SELECT _id, customer, purchases, address FROM " + targetTable, "VALUES (1, 'Aaron', 11, 'Arches'), (2, 'Bill', 7, 'Buena'), (3, 'Carol', 10, 'Etherville'), (7, 'Dave', 11, 'Darbyshire')"); + + assertUpdate("DROP TABLE " + sourceTable); + assertUpdate("DROP TABLE " + targetTable); + } + @Test @Override public void testColumnName() @@ -295,46 +315,6 @@ public void testInsertWithEveryType() assertThat(getQueryRunner().tableExists(getSession(), tableName)).isFalse(); } - @Test - @Override - public void testDeleteWithComplexPredicate() - { - assertThatThrownBy(super::testDeleteWithComplexPredicate) - .hasStackTraceContaining("TrinoException: " + MODIFYING_ROWS_MESSAGE); - } - - @Test - @Override - public void testDeleteWithLike() - { - assertThatThrownBy(super::testDeleteWithLike) - .hasStackTraceContaining("TrinoException: " + MODIFYING_ROWS_MESSAGE); - } - - @Test - @Override - public void testDeleteWithSemiJoin() - { - assertThatThrownBy(super::testDeleteWithSemiJoin) - .hasStackTraceContaining("TrinoException: " + MODIFYING_ROWS_MESSAGE); - } - - @Test - @Override - public void testDeleteWithSubquery() - { - assertThatThrownBy(super::testDeleteWithSubquery) - .hasStackTraceContaining("TrinoException: " + MODIFYING_ROWS_MESSAGE); - } - - @Test - @Override - public void testExplainAnalyzeWithDeleteWithSubquery() - { - assertThatThrownBy(super::testExplainAnalyzeWithDeleteWithSubquery) - .hasStackTraceContaining("TrinoException: " + MODIFYING_ROWS_MESSAGE); - } - @Test public void testPredicatePushdown() { From bec51767ca754289969d754fda310d2a9b8e01aa Mon Sep 17 00:00:00 2001 From: chenjian2664 Date: Mon, 10 Jun 2024 19:02:20 +0800 Subject: [PATCH 2/2] Support fte MERGE for Mongo connector --- .../trino/plugin/mongodb/MongoMergeSink.java | 58 +++- .../plugin/mongodb/MongoMergeTableHandle.java | 4 + .../trino/plugin/mongodb/MongoMetadata.java | 271 ++++++++++++++++-- .../plugin/mongodb/MongoPageSinkProvider.java | 2 + .../mongodb/BaseMongoFailureRecoveryTest.java | 32 --- 5 files changed, 301 insertions(+), 66 deletions(-) diff --git a/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoMergeSink.java b/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoMergeSink.java index e96998aa5e1c..573a8616969c 100644 --- a/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoMergeSink.java +++ b/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoMergeSink.java @@ -42,6 +42,7 @@ import java.util.stream.IntStream; import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkState; import static com.google.common.collect.ImmutableSet.toImmutableSet; import static com.mongodb.client.model.Filters.in; import static io.trino.plugin.mongodb.MongoMetadata.MERGE_ROW_ID_BASE_NAME; @@ -69,6 +70,8 @@ public MongoMergeSink( List columns, Map> updateCaseColumns, MongoColumnHandle mergeColumnHandle, + Optional deleteOutputTableHandle, + Optional updateOutputTableHandle, String implicitPrefix, Optional pageSinkIdColumnName, ConnectorPageSinkId pageSinkId) @@ -83,7 +86,7 @@ public MongoMergeSink( this.columnCount = columns.size(); this.insertSink = new MongoPageSink(mongoSession, remoteTableName, columns, implicitPrefix, pageSinkIdColumnName, pageSinkId); - this.deleteSink = new MongoDeleteSink(mongoSession, remoteTableName, ImmutableList.of(mergeColumnHandle), implicitPrefix, pageSinkIdColumnName, pageSinkId); + this.deleteSink = createDeleteSink(mongoSession, deleteOutputTableHandle, remoteTableName, mergeColumnHandle, implicitPrefix, pageSinkIdColumnName, pageSinkId); ImmutableMap.Builder> updateSinksBuilder = ImmutableMap.builder(); ImmutableMap.Builder updateCaseChannelsBuilder = ImmutableMap.builder(); @@ -97,6 +100,7 @@ public MongoMergeSink( .collect(toImmutableSet()); Supplier updateSupplier = Suppliers.memoize(() -> createUpdateSink( mongoSession, + updateOutputTableHandle, remoteTableName, updateColumns, mergeColumnHandle, @@ -201,6 +205,7 @@ public void abort() private static ConnectorPageSink createUpdateSink( MongoSession mongoSession, + Optional outputTableHandle, RemoteTableName remoteTableName, Collection columns, MongoColumnHandle mergeColumnHandle, @@ -208,12 +213,55 @@ private static ConnectorPageSink createUpdateSink( Optional pageSinkIdColumnName, ConnectorPageSinkId pageSinkId) { - // Update should always include id column explicitly - List updateColumns = ImmutableList.builderWithExpectedSize(columns.size() + 1) + if (outputTableHandle.isEmpty()) { + // Update should always include id column explicitly + List updateColumns = ImmutableList.builderWithExpectedSize(columns.size() + 1) + .addAll(columns) + .add(mergeColumnHandle) + .build(); + return new MongoUpdateSink(mongoSession, remoteTableName, updateColumns, implicitPrefix, pageSinkIdColumnName, pageSinkId); + } + + MongoOutputTableHandle mongoOutputTableHandle = outputTableHandle.get(); + checkState(mongoOutputTableHandle.getTemporaryRemoteTableName().isPresent(), "temporary table not exist"); + + List updateColumns = ImmutableList.builder() .addAll(columns) - .add(mergeColumnHandle) + .addAll(mongoOutputTableHandle.columns()) .build(); - return new MongoUpdateSink(mongoSession, remoteTableName, updateColumns, implicitPrefix, pageSinkIdColumnName, pageSinkId); + + return new MongoPageSink( + mongoSession, + mongoOutputTableHandle.getTemporaryRemoteTableName().get(), + updateColumns, + implicitPrefix, + pageSinkIdColumnName, + pageSinkId); + } + + private static ConnectorPageSink createDeleteSink( + MongoSession mongoSession, + Optional outputTableHandle, + RemoteTableName remoteTableName, + MongoColumnHandle mergeColumnHandle, + String implicitPrefix, + Optional pageSinkIdColumnName, + ConnectorPageSinkId pageSinkId) + { + if (outputTableHandle.isEmpty()) { + return new MongoDeleteSink(mongoSession, remoteTableName, ImmutableList.of(mergeColumnHandle), implicitPrefix, pageSinkIdColumnName, pageSinkId); + } + + MongoOutputTableHandle mongoOutputTableHandle = outputTableHandle.get(); + checkState(mongoOutputTableHandle.getTemporaryRemoteTableName().isPresent(), "temporary table not exist"); + + return new MongoPageSink( + mongoSession, + mongoOutputTableHandle.getTemporaryRemoteTableName().get(), + mongoOutputTableHandle.columns(), + implicitPrefix, + pageSinkIdColumnName, + pageSinkId); } private static class MongoUpdateSink diff --git a/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoMergeTableHandle.java b/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoMergeTableHandle.java index a45029db15fd..9f6f092923d0 100644 --- a/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoMergeTableHandle.java +++ b/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoMergeTableHandle.java @@ -39,6 +39,8 @@ public record MongoMergeTableHandle( MongoColumnHandle mergeRowIdColumn, Optional filter, TupleDomain constraint, + Optional deleteOutputTableHandle, + Optional updateOutputTableHandle, Optional temporaryTableName, Optional pageSinkIdColumnName) implements ConnectorMergeTableHandle @@ -51,6 +53,8 @@ public record MongoMergeTableHandle( requireNonNull(filter, "filter is null"); requireNonNull(mergeRowIdColumn, "mergeRowIdColumn is null"); requireNonNull(constraint, "constraint is null"); + requireNonNull(deleteOutputTableHandle, "deleteOutputTableHandle is null"); + requireNonNull(updateOutputTableHandle, "updateOutputTableHandle is null"); requireNonNull(temporaryTableName, "temporaryTableName is null"); requireNonNull(pageSinkIdColumnName, "pageSinkIdColumnName is null"); checkArgument(temporaryTableName.isPresent() == pageSinkIdColumnName.isPresent(), diff --git a/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoMetadata.java b/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoMetadata.java index cf7a72e782da..626ad7d03e15 100644 --- a/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoMetadata.java +++ b/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoMetadata.java @@ -19,6 +19,7 @@ import com.google.common.collect.Streams; import com.google.common.io.Closer; import com.mongodb.client.MongoCollection; +import com.mongodb.client.model.MergeOptions; import io.airlift.log.Logger; import io.airlift.slice.Slice; import io.trino.plugin.base.projection.ApplyProjectionUtil; @@ -73,6 +74,7 @@ import org.bson.Document; import java.io.IOException; +import java.util.ArrayList; import java.util.Collection; import java.util.HashMap; import java.util.Iterator; @@ -87,16 +89,21 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; +import static com.google.common.base.Strings.isNullOrEmpty; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableMap.toImmutableMap; import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static com.google.common.collect.Iterables.getLast; import static com.google.common.collect.MoreCollectors.onlyElement; +import static com.mongodb.client.model.Aggregates.addFields; import static com.mongodb.client.model.Aggregates.lookup; import static com.mongodb.client.model.Aggregates.match; import static com.mongodb.client.model.Aggregates.merge; import static com.mongodb.client.model.Aggregates.project; +import static com.mongodb.client.model.Filters.in; import static com.mongodb.client.model.Filters.ne; import static com.mongodb.client.model.Projections.exclude; +import static com.mongodb.client.model.Projections.fields; import static io.trino.plugin.base.TemporaryTables.generateTemporaryTableName; import static io.trino.plugin.base.projection.ApplyProjectionUtil.ProjectedColumnRepresentation; import static io.trino.plugin.base.projection.ApplyProjectionUtil.extractSupportedProjectedColumns; @@ -138,7 +145,9 @@ public class MongoMetadata private final MongoSession mongoSession; - private final AtomicReference rollbackAction = new AtomicReference<>(); + private final AtomicReference insertRollbackAction = new AtomicReference<>(); + private final AtomicReference deleteRollbackAction = new AtomicReference<>(); + private final AtomicReference updateRollBackAction = new AtomicReference<>(); public MongoMetadata(MongoSession mongoSession) { @@ -411,7 +420,7 @@ public ConnectorOutputTableHandle beginCreateTable( Closer closer = Closer.create(); closer.register(() -> mongoSession.dropTable(remoteTableName)); - setRollback(() -> { + setInsertRollback(() -> { try { closer.close(); } @@ -451,7 +460,7 @@ public Optional finishCreateTable(ConnectorSession sess if (handle.temporaryTableName().isPresent()) { finishInsert(session, handle.remoteTableName(), handle.getTemporaryRemoteTableName().get(), handle.pageSinkIdColumnName().get(), fragments); } - clearRollback(); + clearInsertRollback(); return Optional.empty(); } @@ -482,7 +491,7 @@ public ConnectorInsertTableHandle beginInsert(ConnectorSession session, Connecto RemoteTableName temporaryTable = new RemoteTableName(handle.schemaTableName().getSchemaName(), generateTemporaryTableName(session)); mongoSession.createTable(temporaryTable, allColumns, Optional.empty()); - setRollback(() -> mongoSession.dropTable(temporaryTable)); + setInsertRollback(() -> mongoSession.dropTable(temporaryTable)); return new MongoInsertTableHandle( handle.remoteTableName(), @@ -503,7 +512,7 @@ public Optional finishInsert( if (handle.temporaryTableName().isPresent()) { finishInsert(session, handle.remoteTableName(), handle.getTemporaryRemoteTableName().get(), handle.pageSinkIdColumnName().get(), fragments); } - clearRollback(); + clearInsertRollback(); return Optional.empty(); } @@ -516,20 +525,8 @@ private void finishInsert( { Closer closer = Closer.create(); closer.register(() -> mongoSession.dropTable(temporaryTable)); - try { - // Create the temporary page sink ID table - RemoteTableName pageSinkIdsTable = new RemoteTableName(temporaryTable.databaseName(), generateTemporaryTableName(session)); - MongoColumnHandle pageSinkIdColumn = new MongoColumnHandle(pageSinkIdColumnName, ImmutableList.of(), TRINO_PAGE_SINK_ID_COLUMN_TYPE, false, false, Optional.empty()); - mongoSession.createTable(pageSinkIdsTable, ImmutableList.of(pageSinkIdColumn), Optional.empty()); - closer.register(() -> mongoSession.dropTable(pageSinkIdsTable)); - - // Insert all the page sink IDs into the page sink ID table - MongoCollection pageSinkIdsCollection = mongoSession.getCollection(pageSinkIdsTable); - List pageSinkIds = fragments.stream() - .map(slice -> new Document(pageSinkIdColumnName, slice.getLong(0))) - .collect(toImmutableList()); - pageSinkIdsCollection.insertMany(pageSinkIds); + RemoteTableName pageSinkIdsTable = getPageSinkIdsTable(session, temporaryTable, pageSinkIdColumnName, fragments, closer); MongoCollection temporaryCollection = mongoSession.getCollection(temporaryTable); temporaryCollection.aggregate(ImmutableList.of( @@ -549,6 +546,171 @@ private void finishInsert( } } + private RemoteTableName getPageSinkIdsTable(ConnectorSession session, RemoteTableName temporaryTable, String pageSinkIdColumnName, Collection fragments, Closer closer) + { + // Create the temporary page sink ID table + RemoteTableName pageSinkIdsTable = new RemoteTableName(temporaryTable.databaseName(), generateTemporaryTableName(session)); + MongoColumnHandle pageSinkIdColumn = new MongoColumnHandle(pageSinkIdColumnName, ImmutableList.of(), TRINO_PAGE_SINK_ID_COLUMN_TYPE, false, false, Optional.empty()); + mongoSession.createTable(pageSinkIdsTable, ImmutableList.of(pageSinkIdColumn), Optional.empty()); + closer.register(() -> mongoSession.dropTable(pageSinkIdsTable)); + + // Insert all the page sink IDs into the page sink ID table + MongoCollection pageSinkIdsCollection = mongoSession.getCollection(pageSinkIdsTable); + List pageSinkIds = fragments.stream() + .map(slice -> new Document(pageSinkIdColumnName, slice.getLong(0))) + .collect(toImmutableList()); + pageSinkIdsCollection.insertMany(pageSinkIds); + return pageSinkIdsTable; + } + + private Optional beginDelete(ConnectorSession session, MongoTableHandle tableHandle, RetryMode retryMode) + { + if (retryMode != RetryMode.RETRIES_ENABLED) { + return Optional.empty(); + } + + MongoColumnHandle rowIdColumn = (MongoColumnHandle) getMergeRowIdColumnHandle(session, tableHandle); + + String rowIdValueColumnName = getNonDuplicatedColumnName(MERGE_ROW_ID_BASE_NAME, ImmutableSet.of(MERGE_ROW_ID_BASE_NAME)); + MongoColumnHandle pageSinkIdColumn = buildPageSinkIdColumn(ImmutableSet.of(MERGE_ROW_ID_BASE_NAME, rowIdValueColumnName)); + MongoColumnHandle rowIdValueColumn = new MongoColumnHandle(rowIdValueColumnName, ImmutableList.of(), rowIdColumn.type(), false, false, Optional.empty()); + List allColumns = ImmutableList.builder() + .add(rowIdValueColumn) + .add(pageSinkIdColumn) + .build(); + + RemoteTableName temporaryTable = new RemoteTableName(tableHandle.schemaTableName().getSchemaName(), generateTemporaryTableName(session)); + mongoSession.createTable(temporaryTable, allColumns, Optional.empty()); + + setDeleteRollback(() -> mongoSession.dropTable(temporaryTable)); + + return Optional.of(new MongoOutputTableHandle( + tableHandle.remoteTableName(), + ImmutableList.of(rowIdValueColumn), + Optional.of(temporaryTable.collectionName()), + Optional.of(pageSinkIdColumn.baseName()))); + } + + private void finishDelete( + ConnectorSession session, + RemoteTableName targetTable, + RemoteTableName temporaryTable, + String pageSinkIdColumnName, + List columnHandles, + Collection fragments) + { + String rowIdName = getLast(columnHandles).baseName(); + checkArgument(!isNullOrEmpty(rowIdName) && rowIdName.startsWith(MERGE_ROW_ID_BASE_NAME), "Error rowId name: " + rowIdName); + + Closer closer = Closer.create(); + closer.register(() -> mongoSession.dropTable(temporaryTable)); + + try { + RemoteTableName pageSinkIdsTable = getPageSinkIdsTable(session, temporaryTable, pageSinkIdColumnName, fragments, closer); + + MongoCollection temporaryCollection = mongoSession.getCollection(temporaryTable); + List idsToDelete = new ArrayList<>(); + temporaryCollection.aggregate(ImmutableList.of( + lookup(pageSinkIdsTable.collectionName(), pageSinkIdColumnName, pageSinkIdColumnName, "page_sink_id"), + match(ne("page_sink_id", ImmutableList.of())))) + .map(document -> document.get(rowIdName)).into(idsToDelete); + + MongoCollection targetCollection = mongoSession.getCollection(targetTable); + targetCollection.deleteMany(in(MERGE_ROW_ID_BASE_NAME, idsToDelete)); + } + finally { + try { + closer.close(); + } + catch (IOException e) { + throw new TrinoException(GENERIC_INTERNAL_ERROR, e); + } + } + } + + private Optional beginUpdate(ConnectorSession session, MongoTableHandle tableHandle, RetryMode retryMode) + { + if (retryMode != RetryMode.RETRIES_ENABLED) { + return Optional.empty(); + } + + MongoColumnHandle rowIdColumn = (MongoColumnHandle) getMergeRowIdColumnHandle(session, tableHandle); + + MongoTable table = mongoSession.getTable(tableHandle.schemaTableName()); + List columns = table.columns(); + + Set allColumnNames = columns.stream() + .map(MongoColumnHandle::baseName) + .collect(toImmutableSet()); + + MongoColumnHandle pageSinkIdColumn = buildPageSinkIdColumn(allColumnNames); + MongoColumnHandle rowIdValueColumn = new MongoColumnHandle( + getNonDuplicatedColumnName(MERGE_ROW_ID_BASE_NAME, allColumnNames), + ImmutableList.of(), + rowIdColumn.type(), + false, + false, + Optional.empty()); + + List allColumns = ImmutableList.builderWithExpectedSize(columns.size() + 2) + .addAll(columns) + .add(rowIdValueColumn) + .add(pageSinkIdColumn) + .build(); + RemoteTableName temporaryTable = new RemoteTableName(tableHandle.schemaTableName().getSchemaName(), generateTemporaryTableName(session)); + mongoSession.createTable(temporaryTable, allColumns, Optional.empty()); + + setUpdateRollBack(() -> mongoSession.dropTable(temporaryTable)); + + return Optional.of(new MongoOutputTableHandle( + tableHandle.remoteTableName(), + ImmutableList.of(rowIdValueColumn), + Optional.of(temporaryTable.collectionName()), + Optional.of(pageSinkIdColumn.baseName()))); + } + + private void finishUpdate( + ConnectorSession session, + RemoteTableName targetTable, + RemoteTableName temporaryTable, + String pageSinkIdColumnName, + List columnHandles, + Collection fragments) + { + Closer closer = Closer.create(); + closer.register(() -> mongoSession.dropTable(temporaryTable)); + + String rowIdName = getLast(columnHandles).baseName(); + checkArgument(!isNullOrEmpty(rowIdName) && rowIdName.startsWith(MERGE_ROW_ID_BASE_NAME), "Error rowId name: " + rowIdName); + + try { + MongoCollection temporaryCollection = mongoSession.getCollection(temporaryTable); + + RemoteTableName pageSinkIdsTable = getPageSinkIdsTable(session, temporaryTable, pageSinkIdColumnName, fragments, closer); + MergeOptions mergeOptions = new MergeOptions() + .whenMatched(MergeOptions.WhenMatched.MERGE) + .whenNotMatched(MergeOptions.WhenNotMatched.FAIL) + .uniqueIdentifier(MERGE_ROW_ID_BASE_NAME); + + temporaryCollection.aggregate(ImmutableList.of( + lookup(pageSinkIdsTable.collectionName(), pageSinkIdColumnName, pageSinkIdColumnName, "page_sink_id"), + match(ne("page_sink_id", ImmutableList.of())), + // Replace the unique key with the rowIdName reference value + addFields(new com.mongodb.client.model.Field<>(MERGE_ROW_ID_BASE_NAME, "$" + rowIdName)), + project(fields(exclude(rowIdName), exclude("page_sink_id"))), + merge(targetTable.collectionName(), mergeOptions))) + .toCollection(); + } + finally { + try { + closer.close(); + } + catch (IOException e) { + throw new TrinoException(GENERIC_INTERNAL_ERROR, e); + } + } + } + @Override public RowChangeParadigm getRowChangeParadigm(ConnectorSession session, ConnectorTableHandle tableHandle) { @@ -558,12 +720,9 @@ public RowChangeParadigm getRowChangeParadigm(ConnectorSession session, Connecto @Override public ConnectorMergeTableHandle beginMerge(ConnectorSession session, ConnectorTableHandle tableHandle, Map> updateCaseColumns, RetryMode retryMode) { - if (retryMode != RetryMode.NO_RETRIES) { - throw new TrinoException(NOT_SUPPORTED, "This connector does not support MERGE with retries"); - } - MongoTableHandle table = (MongoTableHandle) tableHandle; MongoInsertTableHandle insertTableHandle = (MongoInsertTableHandle) beginInsert(session, tableHandle, ImmutableList.of(), retryMode); + return new MongoMergeTableHandle( insertTableHandle.remoteTableName(), insertTableHandle.columns(), @@ -571,6 +730,8 @@ public ConnectorMergeTableHandle beginMerge(ConnectorSession session, ConnectorT (MongoColumnHandle) getMergeRowIdColumnHandle(session, tableHandle), table.filter(), table.constraint(), + beginDelete(session, table, retryMode), + beginUpdate(session, table, retryMode), insertTableHandle.temporaryTableName(), insertTableHandle.pageSinkIdColumnName()); } @@ -590,6 +751,26 @@ public void finishMerge( tableHandle.temporaryTableName(), tableHandle.pageSinkIdColumnName()); finishInsert(session, insertTableHandle, ImmutableList.of(), fragments, computedStatistics); + + tableHandle.deleteOutputTableHandle().ifPresent(deleteOutputHandle -> + finishDelete( + session, + deleteOutputHandle.remoteTableName(), + deleteOutputHandle.getTemporaryRemoteTableName().get(), + deleteOutputHandle.pageSinkIdColumnName().get(), + deleteOutputHandle.columns(), + fragments)); + clearDeleteRollback(); + + tableHandle.updateOutputTableHandle().ifPresent(updateOutputHandle -> + finishUpdate( + session, + updateOutputHandle.remoteTableName(), + updateOutputHandle.getTemporaryRemoteTableName().get(), + updateOutputHandle.pageSinkIdColumnName().get(), + updateOutputHandle.columns(), + fragments)); + clearUpdateRollback(); } @Override @@ -885,19 +1066,41 @@ public Optional> applyTable return Optional.of(new TableFunctionApplicationResult<>(tableHandle, columnHandles)); } - private void setRollback(Runnable action) + private void setInsertRollback(Runnable action) { - checkState(rollbackAction.compareAndSet(null, action), "rollback action is already set"); + checkState(insertRollbackAction.compareAndSet(null, action), "insertRollbackAction action is already set"); } - private void clearRollback() + private void clearInsertRollback() { - rollbackAction.set(null); + insertRollbackAction.set(null); + } + + private void setDeleteRollback(Runnable action) + { + checkState(deleteRollbackAction.compareAndSet(null, action), "deleteRollbackAction action is already set"); + } + + private void clearDeleteRollback() + { + deleteRollbackAction.set(null); + } + + private void setUpdateRollBack(Runnable action) + { + checkState(updateRollBackAction.compareAndSet(null, action), "updateRollBackAction action is already set"); + } + + private void clearUpdateRollback() + { + updateRollBackAction.set(null); } public void rollback() { - Optional.ofNullable(rollbackAction.getAndSet(null)).ifPresent(Runnable::run); + Optional.ofNullable(insertRollbackAction.getAndSet(null)).ifPresent(Runnable::run); + Optional.ofNullable(deleteRollbackAction.getAndSet(null)).ifPresent(Runnable::run); + Optional.ofNullable(updateRollBackAction.getAndSet(null)).ifPresent(Runnable::run); } private static SchemaTableName getTableName(ConnectorTableHandle tableHandle) @@ -931,16 +1134,26 @@ private static void validateColumnNameForInsert(String columnName) } private static MongoColumnHandle buildPageSinkIdColumn(Set otherColumnNames) + { + return new MongoColumnHandle( + getNonDuplicatedColumnName("trino_page_sink_id", otherColumnNames), + ImmutableList.of(), + TRINO_PAGE_SINK_ID_COLUMN_TYPE, + false, + false, + Optional.empty()); + } + + private static String getNonDuplicatedColumnName(String baseColumnName, Set otherColumnNames) { // While it's unlikely this column name will collide with client table columns, // guarantee it will not by appending a deterministic suffix to it. - String baseColumnName = "trino_page_sink_id"; String columnName = baseColumnName; int suffix = 1; while (otherColumnNames.contains(columnName)) { columnName = baseColumnName + "_" + suffix; suffix++; } - return new MongoColumnHandle(columnName, ImmutableList.of(), TRINO_PAGE_SINK_ID_COLUMN_TYPE, false, false, Optional.empty()); + return columnName; } } diff --git a/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoPageSinkProvider.java b/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoPageSinkProvider.java index ee35b8e292e4..929ed632353f 100644 --- a/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoPageSinkProvider.java +++ b/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoPageSinkProvider.java @@ -63,6 +63,8 @@ public ConnectorMergeSink createMergeSink(ConnectorTransactionHandle transaction handle.columns(), handle.updateCaseColumns(), handle.mergeRowIdColumn(), + handle.deleteOutputTableHandle(), + handle.updateOutputTableHandle(), implicitPrefix, handle.pageSinkIdColumnName(), pageSinkId); diff --git a/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/BaseMongoFailureRecoveryTest.java b/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/BaseMongoFailureRecoveryTest.java index 327365ab3ffd..a1b36d34f0e8 100644 --- a/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/BaseMongoFailureRecoveryTest.java +++ b/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/BaseMongoFailureRecoveryTest.java @@ -82,22 +82,6 @@ protected void testDelete() .isCoordinatorOnly(); } - @Test - @Override - protected void testDeleteWithSubquery() - { - // TODO: solve https://github.com/trinodb/trino/issues/22256 - assertThatThrownBy(super::testDeleteWithSubquery).hasMessageContaining("This connector does not support MERGE with retries"); - } - - @Test - @Override - protected void testMerge() - { - // TODO: solve https://github.com/trinodb/trino/issues/22256 - assertThatThrownBy(super::testMerge).hasMessageContaining("This connector does not support MERGE with retries"); - } - @Test @Override protected void testRefreshMaterializedView() @@ -107,22 +91,6 @@ protected void testRefreshMaterializedView() abort("skipped"); } - @Test - @Override - protected void testUpdate() - { - // TODO: solve https://github.com/trinodb/trino/issues/22256 - assertThatThrownBy(super::testUpdate).hasMessageContaining("This connector does not support MERGE with retries"); - } - - @Test - @Override - protected void testUpdateWithSubquery() - { - // TODO: solve https://github.com/trinodb/trino/issues/22256 - assertThatThrownBy(super::testUpdateWithSubquery).hasMessageContaining("This connector does not support MERGE with retries"); - } - @Override protected boolean areWriteRetriesSupported() {