diff --git a/docs/src/main/sphinx/connector/mongodb.md b/docs/src/main/sphinx/connector/mongodb.md
index ebdd4e219991..83306597e228 100644
--- a/docs/src/main/sphinx/connector/mongodb.md
+++ b/docs/src/main/sphinx/connector/mongodb.md
@@ -476,6 +476,8 @@ statements, the connector supports the following features:
 
 - {doc}`/sql/insert`
 - {doc}`/sql/delete`
+- {doc}`/sql/update`
+- {doc}`/sql/merge`
 - {doc}`/sql/create-table`
 - {doc}`/sql/create-table-as`
 - {doc}`/sql/drop-table`
diff --git a/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoMergeSink.java b/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoMergeSink.java
new file mode 100644
index 000000000000..e96998aa5e1c
--- /dev/null
+++ b/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoMergeSink.java
@@ -0,0 +1,309 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.trino.plugin.mongodb;
+
+import com.google.common.base.Suppliers;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import com.mongodb.client.MongoCollection;
+import com.mongodb.client.model.BulkWriteOptions;
+import com.mongodb.client.model.UpdateOneModel;
+import com.mongodb.client.model.WriteModel;
+import io.airlift.slice.Slice;
+import io.trino.spi.Page;
+import io.trino.spi.block.Block;
+import io.trino.spi.connector.ColumnHandle;
+import io.trino.spi.connector.ConnectorMergeSink;
+import io.trino.spi.connector.ConnectorPageSink;
+import io.trino.spi.connector.ConnectorPageSinkId;
+import org.bson.Document;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.Set;
+import java.util.concurrent.CompletableFuture;
+import java.util.function.Supplier;
+import java.util.stream.IntStream;
+
+import static com.google.common.base.Preconditions.checkArgument;
+import static com.google.common.collect.ImmutableList.toImmutableList;
+import static com.google.common.collect.ImmutableSet.toImmutableSet;
+import static com.mongodb.client.model.Filters.in;
+import static io.trino.plugin.mongodb.MongoMetadata.MERGE_ROW_ID_BASE_NAME;
+import static io.trino.spi.type.IntegerType.INTEGER;
+import static io.trino.spi.type.TinyintType.TINYINT;
+import static java.util.Objects.requireNonNull;
+
+public class MongoMergeSink
+        implements ConnectorMergeSink
+{
+    private static final String SET = "$set";
+
+    private final int columnCount;
+
+    private final ConnectorPageSink insertSink;
+    private final ConnectorPageSink deleteSink;
+
+    private final Map<Integer, Supplier<ConnectorPageSink>> updateSinkSuppliers;
+    private final Map<Integer, int[]> updateCaseChannels;
+
+    public MongoMergeSink(
+            MongoSession mongoSession,
+            RemoteTableName remoteTableName,
+            List<MongoColumnHandle> columns,
+            Map<Integer, Collection<ColumnHandle>> updateCaseColumns,
+            MongoColumnHandle mergeColumnHandle,
+            String implicitPrefix,
+            Optional<String> pageSinkIdColumnName,
+            ConnectorPageSinkId pageSinkId)
+    {
+        requireNonNull(mongoSession, "mongoSession is null");
+        requireNonNull(remoteTableName, "remoteTableName is null");
+        requireNonNull(columns, "columns is null");
+        requireNonNull(updateCaseColumns, "updateCaseColumns is null");
+        requireNonNull(pageSinkId, "pageSinkId is null");
+        requireNonNull(mergeColumnHandle, "mergeColumnHandle is null");
+
+        this.columnCount = columns.size();
+
+        this.insertSink = new MongoPageSink(mongoSession, remoteTableName, columns, implicitPrefix, pageSinkIdColumnName, pageSinkId);
+        this.deleteSink = new MongoDeleteSink(mongoSession, remoteTableName, ImmutableList.of(mergeColumnHandle), implicitPrefix, pageSinkIdColumnName, pageSinkId);
+
+        ImmutableMap.Builder<Integer, Supplier<ConnectorPageSink>> updateSinksBuilder = ImmutableMap.builder();
+        ImmutableMap.Builder<Integer, int[]> updateCaseChannelsBuilder = ImmutableMap.builder();
+        for (Map.Entry<Integer, Collection<ColumnHandle>> entry : updateCaseColumns.entrySet()) {
+            int caseNumber = entry.getKey();
+            List<MongoColumnHandle> updateColumns = entry.getValue().stream()
+                    .map(MongoColumnHandle.class::cast)
+                    .collect(toImmutableList());
+            Set<Integer> columnChannels = updateColumns.stream()
+                    .map(columns::indexOf)
+                    .collect(toImmutableSet());
+            Supplier<ConnectorPageSink> updateSupplier = Suppliers.memoize(() -> createUpdateSink(
+                    mongoSession,
+                    remoteTableName,
+                    updateColumns,
+                    mergeColumnHandle,
+                    implicitPrefix,
+                    pageSinkIdColumnName,
+                    pageSinkId));
+            updateSinksBuilder.put(caseNumber, updateSupplier);
+            updateCaseChannelsBuilder.put(caseNumber, columnChannels.stream().mapToInt(Integer::intValue).sorted().toArray());
+        }
+        this.updateSinkSuppliers = updateSinksBuilder.buildOrThrow();
+        this.updateCaseChannels = updateCaseChannelsBuilder.buildOrThrow();
+    }
+
+    @Override
+    public void storeMergedRows(Page page)
+    {
+        checkArgument(page.getChannelCount() == 3 + columnCount, "The page size should be 3 + columnCount (%s), but is %s", columnCount, page.getChannelCount());
+        int positionCount = page.getPositionCount();
+        Block operationBlock = page.getBlock(columnCount);
+        Block rowId = page.getBlock(columnCount + 2);
+
+        int[] dataChannel = IntStream.range(0, columnCount + 1).toArray();
+        dataChannel[columnCount] = columnCount + 2;
+        Page dataPage = page.getColumns(dataChannel);
+
+        int[] insertPositions = new int[positionCount];
+        int insertPositionCount = 0;
+        int[] deletePositions = new int[positionCount];
+        int deletePositionCount = 0;
+
+        Block updateCaseBlock = page.getBlock(columnCount + 1);
+        Map<Integer, int[]> updatePositions = new HashMap<>();
+        Map<Integer, Integer> updatePositionCounts = new HashMap<>();
+
+        for (int position = 0; position < positionCount; position++) {
+            int operation = TINYINT.getByte(operationBlock, position);
+            switch (operation) {
+                case INSERT_OPERATION_NUMBER -> {
+                    insertPositions[insertPositionCount] = position;
+                    insertPositionCount++;
+                }
+                case DELETE_OPERATION_NUMBER -> {
+                    deletePositions[deletePositionCount] = position;
+                    deletePositionCount++;
+                }
+                case UPDATE_OPERATION_NUMBER -> {
+                    int caseNumber = INTEGER.getInt(updateCaseBlock, position);
+                    int updatePositionCount = updatePositionCounts.getOrDefault(caseNumber, 0);
+                    updatePositions.computeIfAbsent(caseNumber, _ -> new int[positionCount])[updatePositionCount] = position;
+                    updatePositionCounts.put(caseNumber, updatePositionCount + 1);
+                }
+                default -> throw new IllegalStateException("Unexpected value: " + operation);
+            }
+        }
+
+        if (deletePositionCount > 0) {
+            Block positions = rowId.getPositions(deletePositions, 0, deletePositionCount);
+            deleteSink.appendPage(new Page(deletePositionCount, positions));
+        }
+
+        for (Map.Entry<Integer, Integer> entry : updatePositionCounts.entrySet()) {
+            int caseNumber = entry.getKey();
+            int updatePositionCount = entry.getValue();
+            if (updatePositionCount > 0) {
+                checkArgument(updatePositions.containsKey(caseNumber), "Unexpected case number %s", caseNumber);
+                int[] positions = updatePositions.get(caseNumber);
+                int[] updateAssignmentChannels = updateCaseChannels.get(caseNumber);
+                Block[] updateBlocks = new Block[updateAssignmentChannels.length + 1];
+                for (int channel = 0; channel < updateAssignmentChannels.length; channel++) {
+                    updateBlocks[channel] = dataPage.getBlock(updateAssignmentChannels[channel]).getPositions(positions, 0, updatePositionCount);
+                }
+                updateBlocks[updateAssignmentChannels.length] = rowId.getPositions(positions, 0, updatePositionCount);
+
+                updateSinkSuppliers.get(caseNumber).get().appendPage(new Page(updatePositionCount, updateBlocks));
+            }
+        }
+
+        if (insertPositionCount > 0) {
+            // Insert page should not include _id column by default, unless the insert columns include it explicitly
+            Page insertPage = dataPage.getColumns(Arrays.copyOf(dataChannel, columnCount));
+            insertSink.appendPage(insertPage.getPositions(insertPositions, 0, insertPositionCount));
+        }
+    }
+
+    @Override
+    public CompletableFuture<Collection<Slice>> finish()
+    {
+        CompletableFuture<Collection<Slice>> finish = insertSink.finish();
+        deleteSink.finish();
+        updateSinkSuppliers.values().stream().map(Supplier::get).forEach(ConnectorPageSink::finish);
+
+        return finish;
+    }
+
+    @Override
+    public void abort()
+    {
+        insertSink.abort();
+        deleteSink.abort();
+        updateSinkSuppliers.values().stream().map(Supplier::get).forEach(ConnectorPageSink::abort);
+    }
+
+    private static ConnectorPageSink createUpdateSink(
+            MongoSession mongoSession,
+            RemoteTableName remoteTableName,
+            Collection<MongoColumnHandle> columns,
+            MongoColumnHandle mergeColumnHandle,
+            String implicitPrefix,
+            Optional<String> pageSinkIdColumnName,
+            ConnectorPageSinkId pageSinkId)
+    {
+        // Update should always include id column explicitly
+        List<MongoColumnHandle> updateColumns = ImmutableList.<MongoColumnHandle>builderWithExpectedSize(columns.size() + 1)
+                .addAll(columns)
+                .add(mergeColumnHandle)
+                .build();
+        return new MongoUpdateSink(mongoSession, remoteTableName, updateColumns, implicitPrefix, pageSinkIdColumnName, pageSinkId);
+    }
+
+    private static class MongoUpdateSink
+            implements ConnectorPageSink
+    {
+        private final MongoPageSink delegate;
+        private final MongoSession mongoSession;
+        private final RemoteTableName remoteTableName;
+
+        public MongoUpdateSink(
+                MongoSession mongoSession,
+                RemoteTableName remoteTableName,
+                List<MongoColumnHandle> columns,
+                String implicitPrefix,
+                Optional<String> pageSinkIdColumnName,
+                ConnectorPageSinkId pageSinkId)
+        {
+            this.delegate = new MongoPageSink(mongoSession, remoteTableName, columns, implicitPrefix, pageSinkIdColumnName, pageSinkId);
+            this.mongoSession = requireNonNull(mongoSession, "mongoSession is null");
+            this.remoteTableName = requireNonNull(remoteTableName, "remoteTableName is null");
+        }
+
+        @Override
+        public CompletableFuture<?> appendPage(Page page)
+        {
+            MongoCollection<Document> collection = mongoSession.getCollection(remoteTableName);
+            List<WriteModel<Document>> bulkWrites = new ArrayList<>();
+            for (Document document : delegate.buildBatchDocumentsFromPage(page)) {
+                Document filter = new Document(MERGE_ROW_ID_BASE_NAME, document.get(MERGE_ROW_ID_BASE_NAME));
+                bulkWrites.add(new UpdateOneModel<>(filter, new Document(SET, document)));
+            }
+            collection.bulkWrite(bulkWrites, new BulkWriteOptions().ordered(false));
+            return NOT_BLOCKED;
+        }
+
+        @Override
+        public CompletableFuture<Collection<Slice>> finish()
+        {
+            return delegate.finish();
+        }
+
+        @Override
+        public void abort()
+        {
+            delegate.abort();
+        }
+    }
+
+    private static class MongoDeleteSink
+            implements ConnectorPageSink
+    {
+        private final MongoPageSink delegate;
+        private final MongoSession mongoSession;
+        private final RemoteTableName remoteTableName;
+
+        public MongoDeleteSink(
+                MongoSession mongoSession,
+                RemoteTableName remoteTableName,
+                List<MongoColumnHandle> columns,
+                String implicitPrefix,
+                Optional<String> pageSinkIdColumnName,
+                ConnectorPageSinkId pageSinkId)
+        {
+            this.delegate = new MongoPageSink(mongoSession, remoteTableName, columns, implicitPrefix, pageSinkIdColumnName, pageSinkId);
+            this.mongoSession = requireNonNull(mongoSession, "mongoSession is null");
+            this.remoteTableName = requireNonNull(remoteTableName, "remoteTableName is null");
+        }
+
+        @Override
+        public CompletableFuture<?> appendPage(Page page)
+        {
+            MongoCollection<Document> collection = mongoSession.getCollection(remoteTableName);
+            ImmutableList.Builder<Object> idsToDeleteBuilder = ImmutableList.builder();
+            for (Document document : delegate.buildBatchDocumentsFromPage(page)) {
+                idsToDeleteBuilder.add(document.get(MERGE_ROW_ID_BASE_NAME));
+            }
+            collection.deleteMany(in(MERGE_ROW_ID_BASE_NAME, idsToDeleteBuilder.build()));
+            return NOT_BLOCKED;
+        }
+
+        @Override
+        public CompletableFuture<Collection<Slice>> finish()
+        {
+            return delegate.finish();
+        }
+
+        @Override
+        public void abort()
+        {
+            delegate.abort();
+        }
+    }
+}
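For reference, a minimal, self-contained sketch (not part of the patch) of the driver calls the two inner sinks above end up issuing: `MongoUpdateSink` batches one `UpdateOneModel` per merged row into an unordered `bulkWrite` of `$set` documents keyed on `_id`, while `MongoDeleteSink` collapses all row ids from a page into a single `deleteMany` with an `$in` filter. The connection string, database, collection, and values below are made up for illustration:

```java
import com.mongodb.client.MongoClient;
import com.mongodb.client.MongoClients;
import com.mongodb.client.MongoCollection;
import com.mongodb.client.model.BulkWriteOptions;
import com.mongodb.client.model.Filters;
import com.mongodb.client.model.UpdateOneModel;
import org.bson.Document;

import java.util.List;

public class MergeWriteSketch
{
    public static void main(String[] args)
    {
        // Hypothetical connection details, database, and collection
        try (MongoClient client = MongoClients.create("mongodb://localhost:27017")) {
            MongoCollection<Document> collection = client.getDatabase("tpch").getCollection("customers");

            // UPDATE path: one UpdateOneModel per merged row, matched on _id, applied as an unordered bulk write
            collection.bulkWrite(
                    List.of(
                            new UpdateOneModel<>(Filters.eq("_id", 1), new Document("$set", new Document("purchases", 11).append("address", "Arches"))),
                            new UpdateOneModel<>(Filters.eq("_id", 3), new Document("$set", new Document("purchases", 10).append("address", "Etherville")))),
                    new BulkWriteOptions().ordered(false));

            // DELETE path: all _id values from a page collapsed into a single deleteMany with an $in filter
            collection.deleteMany(Filters.in("_id", List.of(4)));
        }
    }
}
```

The unordered option mirrors the sink's behavior: each update targets a distinct `_id`, so ordering between the per-row writes does not matter.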
diff --git a/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoMergeTableHandle.java b/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoMergeTableHandle.java
new file mode 100644
index 000000000000..ddd08c08ae14
--- /dev/null
+++ b/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoMergeTableHandle.java
@@ -0,0 +1,76 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.trino.plugin.mongodb;
+
+import com.fasterxml.jackson.annotation.JsonIgnore;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.ImmutableSet;
+import io.trino.spi.connector.ColumnHandle;
+import io.trino.spi.connector.ConnectorMergeTableHandle;
+import io.trino.spi.connector.ConnectorTableHandle;
+import io.trino.spi.connector.SchemaTableName;
+import io.trino.spi.predicate.TupleDomain;
+
+import java.util.Collection;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.OptionalInt;
+
+import static com.google.common.base.Preconditions.checkArgument;
+import static java.util.Objects.requireNonNull;
+
+public record MongoMergeTableHandle(
+        RemoteTableName remoteTableName,
+        List<MongoColumnHandle> columns,
+        Map<Integer, Collection<ColumnHandle>> updateCaseColumns,
+        MongoColumnHandle mergeRowIdColumn,
+        Optional<String> filter,
+        TupleDomain<ColumnHandle> constraint,
+        Optional<String> temporaryTableName,
+        Optional<String> pageSinkIdColumnName)
+        implements ConnectorMergeTableHandle
+{
+    public MongoMergeTableHandle
+    {
+        requireNonNull(remoteTableName, "remoteTableName is null");
+        columns = ImmutableList.copyOf(requireNonNull(columns, "columns is null"));
+        updateCaseColumns = ImmutableMap.copyOf(requireNonNull(updateCaseColumns, "updateCaseColumns is null"));
+        requireNonNull(filter, "filter is null");
+        requireNonNull(constraint, "constraint is null");
+        requireNonNull(temporaryTableName, "temporaryTableName is null");
+        requireNonNull(pageSinkIdColumnName, "pageSinkIdColumnName is null");
+        checkArgument(temporaryTableName.isPresent() == pageSinkIdColumnName.isPresent(),
+                "temporaryTableName.isPresent is not equal to pageSinkIdColumnName.isPresent");
+    }
+
+    @JsonIgnore
+    public Optional<RemoteTableName> getTemporaryRemoteTableName()
+    {
+        return temporaryTableName.map(tableName -> new RemoteTableName(remoteTableName.databaseName(), tableName));
+    }
+
+    @Override
+    public ConnectorTableHandle getTableHandle()
+    {
+        return new MongoTableHandle(
+                new SchemaTableName(remoteTableName.databaseName(), remoteTableName.collectionName()),
+                remoteTableName,
+                filter,
+                constraint,
+                ImmutableSet.of(),
+                OptionalInt.empty());
+    }
+}
diff --git a/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoMetadata.java b/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoMetadata.java
index e023e61e7211..cf7a72e782da 100644
--- a/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoMetadata.java
+++ b/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoMetadata.java
@@ -29,6 +29,7 @@
 import io.trino.spi.connector.ColumnHandle;
 import io.trino.spi.connector.ColumnMetadata;
 import io.trino.spi.connector.ConnectorInsertTableHandle;
+import io.trino.spi.connector.ConnectorMergeTableHandle;
 import io.trino.spi.connector.ConnectorMetadata;
 import io.trino.spi.connector.ConnectorOutputMetadata;
 import io.trino.spi.connector.ConnectorOutputTableHandle;
@@ -46,6 +47,7 @@
 import io.trino.spi.connector.ProjectionApplicationResult;
 import io.trino.spi.connector.RelationColumnsMetadata;
 import io.trino.spi.connector.RetryMode;
+import io.trino.spi.connector.RowChangeParadigm;
 import io.trino.spi.connector.SaveMode;
 import io.trino.spi.connector.SchemaTableName;
 import io.trino.spi.connector.SchemaTablePrefix;
@@ -107,6 +109,7 @@
 import static io.trino.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR;
 import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED;
 import static io.trino.spi.connector.RelationColumnsMetadata.forTable;
+import static io.trino.spi.connector.RowChangeParadigm.CHANGE_ONLY_UPDATED_COLUMNS;
 import static io.trino.spi.connector.SaveMode.REPLACE;
 import static io.trino.spi.type.BigintType.BIGINT;
 import static io.trino.spi.type.DoubleType.DOUBLE;
@@ -126,6 +129,8 @@ public class MongoMetadata
         implements ConnectorMetadata
 {
+    public static final String MERGE_ROW_ID_BASE_NAME = "_id";
+
     private static final Logger log = Logger.get(MongoMetadata.class);
 
     private static final Type TRINO_PAGE_SINK_ID_COLUMN_TYPE = BigintType.BIGINT;
 
@@ -544,10 +549,56 @@ private void finishInsert(
         }
     }
 
+    @Override
+    public RowChangeParadigm getRowChangeParadigm(ConnectorSession session, ConnectorTableHandle tableHandle)
+    {
+        return CHANGE_ONLY_UPDATED_COLUMNS;
+    }
+
+    @Override
+    public ConnectorMergeTableHandle beginMerge(ConnectorSession session, ConnectorTableHandle tableHandle, Map<Integer, Collection<ColumnHandle>> updateCaseColumns, RetryMode retryMode)
+    {
+        if (retryMode != RetryMode.NO_RETRIES) {
+            throw new TrinoException(NOT_SUPPORTED, "This connector does not support MERGE with retries");
+        }
+
+        MongoTableHandle table = (MongoTableHandle) tableHandle;
+        MongoInsertTableHandle insertTableHandle = (MongoInsertTableHandle) beginInsert(session, tableHandle, ImmutableList.of(), retryMode);
+        return new MongoMergeTableHandle(
+                insertTableHandle.remoteTableName(),
+                insertTableHandle.columns(),
+                updateCaseColumns,
+                (MongoColumnHandle) getMergeRowIdColumnHandle(session, tableHandle),
+                table.filter(),
+                table.constraint(),
+                insertTableHandle.temporaryTableName(),
+                insertTableHandle.pageSinkIdColumnName());
+    }
+
+    @Override
+    public void finishMerge(
+            ConnectorSession session,
+            ConnectorMergeTableHandle mergeTableHandle,
+            List<ConnectorTableHandle> sourceTableHandles,
+            Collection<Slice> fragments,
+            Collection<ComputedStatistics> computedStatistics)
+    {
+        MongoMergeTableHandle tableHandle = (MongoMergeTableHandle) mergeTableHandle;
+        MongoInsertTableHandle insertTableHandle = new MongoInsertTableHandle(
+                tableHandle.remoteTableName(),
+                ImmutableList.copyOf(tableHandle.columns()),
+                tableHandle.temporaryTableName(),
+                tableHandle.pageSinkIdColumnName());
+        finishInsert(session, insertTableHandle, ImmutableList.of(), fragments, computedStatistics);
+    }
+
     @Override
     public ColumnHandle getMergeRowIdColumnHandle(ConnectorSession session, ConnectorTableHandle tableHandle)
     {
-        return new MongoColumnHandle("$merge_row_id", ImmutableList.of(), BIGINT, true, false, Optional.empty());
+        Map<String, ColumnHandle> columnHandles = getColumnHandles(session, tableHandle);
+        checkState(columnHandles.containsKey(MERGE_ROW_ID_BASE_NAME), "id column %s does not exist", MERGE_ROW_ID_BASE_NAME);
+        Type idColumnType = ((MongoColumnHandle) columnHandles.get(MERGE_ROW_ID_BASE_NAME)).type();
+        return new MongoColumnHandle(MERGE_ROW_ID_BASE_NAME, ImmutableList.of(), idColumnType, true, false, Optional.empty());
     }
 
     @Override
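With these metadata entry points in place, `UPDATE` and `MERGE` (now listed in the documentation above) plan through the merge machinery; because `getRowChangeParadigm` returns `CHANGE_ONLY_UPDATED_COLUMNS`, each update case's sink receives only the assigned columns plus the `_id` row id. A usage sketch, with hypothetical catalog, schema, and table names:

```sql
-- Only the address column and the _id row id reach the per-case update sink
UPDATE mongodb.tpch.customers
SET address = 'Arches'
WHERE _id = 1;

-- Mirrors the shape of the MERGE exercised by the new test further down
MERGE INTO mongodb.tpch.customers t
USING mongodb.tpch.customer_updates s
ON (t._id = s._id)
    WHEN MATCHED AND s.address = 'Centreville' THEN DELETE
    WHEN MATCHED THEN UPDATE SET purchases = s.purchases + t.purchases, address = s.address
    WHEN NOT MATCHED THEN INSERT (_id, customer, purchases, address) VALUES (s._id, s.customer, s.purchases, s.address);
```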
diff --git a/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoPageSink.java b/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoPageSink.java
index 06496aac7b09..7d21e426c2b8 100644
--- a/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoPageSink.java
+++ b/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoPageSink.java
@@ -110,6 +110,12 @@ public MongoPageSink(
     public CompletableFuture<?> appendPage(Page page)
     {
         MongoCollection<Document> collection = mongoSession.getCollection(remoteTableName);
+        collection.insertMany(buildBatchDocumentsFromPage(page), new InsertManyOptions().ordered(true));
+        return NOT_BLOCKED;
+    }
+
+    protected List<Document> buildBatchDocumentsFromPage(Page page)
+    {
         List<Document> batch = new ArrayList<>(page.getPositionCount());
 
         for (int position = 0; position < page.getPositionCount(); position++) {
@@ -122,9 +128,7 @@ public CompletableFuture<?> appendPage(Page page)
             }
             batch.add(doc);
         }
-
-        collection.insertMany(batch, new InsertManyOptions().ordered(true));
-        return NOT_BLOCKED;
+        return batch;
     }
 
     private Object getObjectValue(Type type, Block block, int position)
diff --git a/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoPageSinkProvider.java b/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoPageSinkProvider.java
index 7ebf8667b4de..ee35b8e292e4 100644
--- a/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoPageSinkProvider.java
+++ b/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoPageSinkProvider.java
@@ -15,6 +15,8 @@
 
 import com.google.inject.Inject;
 import io.trino.spi.connector.ConnectorInsertTableHandle;
+import io.trino.spi.connector.ConnectorMergeSink;
+import io.trino.spi.connector.ConnectorMergeTableHandle;
 import io.trino.spi.connector.ConnectorOutputTableHandle;
 import io.trino.spi.connector.ConnectorPageSink;
 import io.trino.spi.connector.ConnectorPageSinkId;
@@ -50,4 +52,19 @@ public ConnectorPageSink createPageSink(ConnectorTransactionHandle transactionHa
         MongoInsertTableHandle handle = (MongoInsertTableHandle) insertTableHandle;
         return new MongoPageSink(mongoSession, handle.getTemporaryRemoteTableName().orElseGet(handle::remoteTableName), handle.columns(), implicitPrefix, handle.pageSinkIdColumnName(), pageSinkId);
     }
+
+    @Override
+    public ConnectorMergeSink createMergeSink(ConnectorTransactionHandle transactionHandle, ConnectorSession session, ConnectorMergeTableHandle mergeHandle, ConnectorPageSinkId pageSinkId)
+    {
+        MongoMergeTableHandle handle = (MongoMergeTableHandle) mergeHandle;
+        return new MongoMergeSink(
+                mongoSession,
+                handle.getTemporaryRemoteTableName().orElseGet(handle::remoteTableName),
+                handle.columns(),
+                handle.updateCaseColumns(),
+                handle.mergeRowIdColumn(),
+                implicitPrefix,
+                handle.pageSinkIdColumnName(),
+                pageSinkId);
+    }
 }
diff --git a/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/BaseMongoConnectorSmokeTest.java b/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/BaseMongoConnectorSmokeTest.java
index f39404a42ea0..27ad160ab8fc 100644
--- a/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/BaseMongoConnectorSmokeTest.java
+++ b/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/BaseMongoConnectorSmokeTest.java
@@ -30,11 +30,9 @@ protected boolean hasBehavior(TestingConnectorBehavior connectorBehavior)
         return switch (connectorBehavior) {
             case SUPPORTS_CREATE_MATERIALIZED_VIEW,
                  SUPPORTS_CREATE_VIEW,
-                 SUPPORTS_MERGE,
                  SUPPORTS_NOT_NULL_CONSTRAINT,
                  SUPPORTS_RENAME_SCHEMA,
-                 SUPPORTS_TRUNCATE,
-                 SUPPORTS_UPDATE -> false;
+                 SUPPORTS_TRUNCATE -> false;
             default -> super.hasBehavior(connectorBehavior);
         };
     }
diff --git a/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/BaseMongoFailureRecoveryTest.java b/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/BaseMongoFailureRecoveryTest.java
index 8b1d9ca99d1b..327365ab3ffd 100644
--- a/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/BaseMongoFailureRecoveryTest.java
+++ b/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/BaseMongoFailureRecoveryTest.java
@@ -24,6 +24,7 @@
 
 import java.util.List;
 import java.util.Map;
+import java.util.Optional;
 
 import static org.assertj.core.api.Assertions.assertThatThrownBy;
 import static org.junit.jupiter.api.Assumptions.abort;
@@ -69,24 +70,32 @@ protected void testAnalyzeTable()
     @Override
     protected void testDelete()
     {
-        assertThatThrownBy(super::testDeleteWithSubquery).hasMessageContaining("This connector does not support modifying table rows");
-        abort("skipped");
+        // This simple delete on Mongo ends up as a very simple, single-fragment, coordinator-only plan,
+        // which has no ability to recover from errors. This test simply verifies that's still the case.
+        Optional<String> setupQuery = Optional.of("CREATE TABLE <table> AS SELECT * FROM orders");
+        String testQuery = "DELETE FROM <table> WHERE orderkey = 1";
+        Optional<String> cleanupQuery = Optional.of("DROP TABLE <table>");
+
+        assertThatQuery(testQuery)
+                .withSetupQuery(setupQuery)
+                .withCleanupQuery(cleanupQuery)
+                .isCoordinatorOnly();
     }
 
     @Test
     @Override
     protected void testDeleteWithSubquery()
     {
-        assertThatThrownBy(super::testDeleteWithSubquery).hasMessageContaining("This connector does not support modifying table rows");
-        abort("skipped");
+        // TODO: solve https://github.com/trinodb/trino/issues/22256
+        assertThatThrownBy(super::testDeleteWithSubquery).hasMessageContaining("This connector does not support MERGE with retries");
     }
 
     @Test
     @Override
     protected void testMerge()
     {
-        assertThatThrownBy(super::testMerge).hasMessageContaining("This connector does not support modifying table rows");
-        abort("skipped");
+        // TODO: solve https://github.com/trinodb/trino/issues/22256
+        assertThatThrownBy(super::testMerge).hasMessageContaining("This connector does not support MERGE with retries");
     }
 
     @Test
@@ -102,16 +111,16 @@ protected void testRefreshMaterializedView()
     @Override
     protected void testUpdate()
     {
-        assertThatThrownBy(super::testUpdate).hasMessageContaining("This connector does not support modifying table rows");
-        abort("skipped");
+        // TODO: solve https://github.com/trinodb/trino/issues/22256
+        assertThatThrownBy(super::testUpdate).hasMessageContaining("This connector does not support MERGE with retries");
     }
 
     @Test
     @Override
     protected void testUpdateWithSubquery()
     {
-        assertThatThrownBy(super::testUpdateWithSubquery).hasMessageContaining("This connector does not support modifying table rows");
-        abort("skipped");
+        // TODO: solve https://github.com/trinodb/trino/issues/22256
+        assertThatThrownBy(super::testUpdateWithSubquery).hasMessageContaining("This connector does not support MERGE with retries");
    }
 
     @Override
diff --git a/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/TestMongoConnectorTest.java b/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/TestMongoConnectorTest.java
index fd85c4abb8a6..a90a4da19965 100644
--- a/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/TestMongoConnectorTest.java
+++ b/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/TestMongoConnectorTest.java
@@ -53,7 +53,6 @@
 import static com.mongodb.client.model.CollationStrength.PRIMARY;
 import static io.trino.plugin.mongodb.MongoQueryRunner.createMongoClient;
 import static io.trino.plugin.mongodb.TypeUtils.isPushdownSupportedType;
-import static io.trino.spi.connector.ConnectorMetadata.MODIFYING_ROWS_MESSAGE;
 import static io.trino.testing.TestingNames.randomNameSuffix;
 import static java.lang.String.format;
 import static java.nio.charset.StandardCharsets.UTF_8;
@@ -103,13 +102,11 @@ protected boolean hasBehavior(TestingConnectorBehavior connectorBehavior)
                  SUPPORTS_CREATE_MATERIALIZED_VIEW,
                  SUPPORTS_CREATE_VIEW,
                  SUPPORTS_DROP_FIELD,
-                 SUPPORTS_MERGE,
                  SUPPORTS_NOT_NULL_CONSTRAINT,
                  SUPPORTS_RENAME_FIELD,
                  SUPPORTS_RENAME_SCHEMA,
                  SUPPORTS_SET_FIELD_TYPE,
-                 SUPPORTS_TRUNCATE,
-                 SUPPORTS_UPDATE -> false;
+                 SUPPORTS_TRUNCATE -> false;
             default -> super.hasBehavior(connectorBehavior);
         };
     }
@@ -120,6 +117,29 @@ protected TestTable createTableWithDefaultColumns()
         return abort("MongoDB connector does not support column default values");
     }
 
+    @Test
+    public void testMergeWithCustomizeTypeIdColumn()
+    {
+        String targetTable = "merge_with_customize_type_id_column_target_" + randomNameSuffix();
+        String sourceTable = "merge_with_customize_type_id_column_source_" + randomNameSuffix();
+        createTableForWrites("CREATE TABLE %s (_id INT, customer VARCHAR, purchases INT, address VARCHAR)", targetTable, Optional.empty());
+
+        assertUpdate("INSERT INTO %s (_id, customer, purchases, address) VALUES (1, 'Aaron', 5, 'Antioch'), (2, 'Bill', 7, 'Buena'), (3, 'Carol', 3, 'Cambridge'), (4, 'Dave', 11, 'Devon')".formatted(targetTable), 4);
+
+        createTableForWrites("CREATE TABLE %s (_id INT, customer VARCHAR, purchases INT, address VARCHAR)", sourceTable, Optional.empty());
+
+        assertUpdate("INSERT INTO %s (_id, customer, purchases, address) VALUES (1, 'Aaron', 6, 'Arches'), (3, 'Ed', 7, 'Etherville'), (4, 'Carol', 9, 'Centreville'), (7, 'Dave', 11, 'Darbyshire')".formatted(sourceTable), 4);
+
+        assertUpdate("MERGE INTO %s t USING %s s ON (t._id = s._id)".formatted(targetTable, sourceTable) +
+                " WHEN MATCHED AND s.address = 'Centreville' THEN DELETE" +
+                " WHEN MATCHED THEN UPDATE SET purchases = s.purchases + t.purchases, address = s.address" +
+                " WHEN NOT MATCHED THEN INSERT (_id, customer, purchases, address) VALUES(s._id, s.customer, s.purchases, s.address)", 4);
+        assertQuery("SELECT _id, customer, purchases, address FROM " + targetTable, "VALUES (1, 'Aaron', 11, 'Arches'), (2, 'Bill', 7, 'Buena'), (3, 'Carol', 10, 'Etherville'), (7, 'Dave', 11, 'Darbyshire')");
+
+        assertUpdate("DROP TABLE " + sourceTable);
+        assertUpdate("DROP TABLE " + targetTable);
+    }
+
     @Test
     @Override
     public void testColumnName()
@@ -295,46 +315,6 @@ public void testInsertWithEveryType()
         assertThat(getQueryRunner().tableExists(getSession(), tableName)).isFalse();
     }
 
-    @Test
-    @Override
-    public void testDeleteWithComplexPredicate()
-    {
-        assertThatThrownBy(super::testDeleteWithComplexPredicate)
-                .hasStackTraceContaining("TrinoException: " + MODIFYING_ROWS_MESSAGE);
-    }
-
-    @Test
-    @Override
-    public void testDeleteWithLike()
-    {
-        assertThatThrownBy(super::testDeleteWithLike)
-                .hasStackTraceContaining("TrinoException: " + MODIFYING_ROWS_MESSAGE);
-    }
-
-    @Test
-    @Override
-    public void testDeleteWithSemiJoin()
-    {
-        assertThatThrownBy(super::testDeleteWithSemiJoin)
-                .hasStackTraceContaining("TrinoException: " + MODIFYING_ROWS_MESSAGE);
-    }
-
-    @Test
-    @Override
-    public void testDeleteWithSubquery()
-    {
-        assertThatThrownBy(super::testDeleteWithSubquery)
-                .hasStackTraceContaining("TrinoException: " + MODIFYING_ROWS_MESSAGE);
-    }
-
-    @Test
-    @Override
-    public void testExplainAnalyzeWithDeleteWithSubquery()
-    {
-        assertThatThrownBy(super::testExplainAnalyzeWithDeleteWithSubquery)
-                .hasStackTraceContaining("TrinoException: " + MODIFYING_ROWS_MESSAGE);
-    }
-
     @Test
     public void testPredicatePushdown()
     {