Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Restrict insert overwrite to autocommit #9675

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions core/trino-main/src/main/java/io/trino/FullConnectorSession.java
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ public class FullConnectorSession
implements ConnectorSession
{
private final Session session;
private final Optional<Boolean> transactionAutoCommitContext;
private final ConnectorIdentity identity;
private final Map<String, String> properties;
private final CatalogName catalogName;
Expand All @@ -44,6 +45,7 @@ public class FullConnectorSession
public FullConnectorSession(Session session, ConnectorIdentity identity)
{
this.session = requireNonNull(session, "session is null");
this.transactionAutoCommitContext = Optional.empty();
this.identity = requireNonNull(identity, "identity is null");
this.properties = null;
this.catalogName = null;
Expand All @@ -53,13 +55,15 @@ public FullConnectorSession(Session session, ConnectorIdentity identity)

public FullConnectorSession(
Session session,
Optional<Boolean> transactionAutoCommitContext,
ConnectorIdentity identity,
Map<String, String> properties,
CatalogName catalogName,
String catalog,
SessionPropertyManager sessionPropertyManager)
{
this.session = requireNonNull(session, "session is null");
this.transactionAutoCommitContext = requireNonNull(transactionAutoCommitContext, "transactionAutoCommitContext is null");
this.identity = requireNonNull(identity, "identity is null");
this.properties = ImmutableMap.copyOf(requireNonNull(properties, "properties is null"));
this.catalogName = requireNonNull(catalogName, "catalogName is null");
Expand All @@ -72,6 +76,12 @@ public Session getSession()
return session;
}

@Override
public boolean isAutoCommitContext()
{
return transactionAutoCommitContext.orElseThrow(NotInTransactionException::new);
}

@Override
public String getQueryId()
{
Expand Down
20 changes: 15 additions & 5 deletions core/trino-main/src/main/java/io/trino/Session.java
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ public final class Session
private final SessionPropertyManager sessionPropertyManager;
private final Map<String, String> preparedStatements;
private final ProtocolHeaders protocolHeaders;
private final Optional<Boolean> transactionAutoCommitContext;

public Session(
QueryId queryId,
Expand All @@ -110,7 +111,8 @@ public Session(
Map<String, Map<String, String>> unprocessedCatalogProperties,
SessionPropertyManager sessionPropertyManager,
Map<String, String> preparedStatements,
ProtocolHeaders protocolHeaders)
ProtocolHeaders protocolHeaders,
Optional<Boolean> transactionAutoCommitContext)
{
this.queryId = requireNonNull(queryId, "queryId is null");
this.transactionId = requireNonNull(transactionId, "transactionId is null");
Expand Down Expand Up @@ -150,6 +152,8 @@ public Session(
checkArgument(transactionId.isEmpty() || unprocessedCatalogProperties.isEmpty(), "Catalog session properties cannot be set if there is an open transaction");

checkArgument(catalog.isPresent() || schema.isEmpty(), "schema is set but catalog is not");

this.transactionAutoCommitContext = requireNonNull(transactionAutoCommitContext, "transactionId is null");
}

public QueryId getQueryId()
Expand Down Expand Up @@ -359,6 +363,7 @@ public Session beginTransactionId(TransactionId transactionId, TransactionManage
connectorRoles.put(systemTablesCatalogName, role);
}
}
boolean isAutoCommitContext = transactionManager.getTransactionInfo(transactionId).isAutoCommitContext();

return new Session(
queryId,
Expand Down Expand Up @@ -386,7 +391,8 @@ public Session beginTransactionId(TransactionId transactionId, TransactionManage
ImmutableMap.of(),
sessionPropertyManager,
preparedStatements,
protocolHeaders);
protocolHeaders,
Optional.of(isAutoCommitContext));
}

public Session withDefaultProperties(Map<String, String> systemPropertyDefaults, Map<String, Map<String, String>> catalogPropertyDefaults)
Expand Down Expand Up @@ -438,7 +444,8 @@ public Session withDefaultProperties(Map<String, String> systemPropertyDefaults,
connectorProperties,
sessionPropertyManager,
preparedStatements,
protocolHeaders);
protocolHeaders,
transactionAutoCommitContext);
}

public ConnectorSession toConnectorSession()
Expand All @@ -457,6 +464,7 @@ public ConnectorSession toConnectorSession(CatalogName catalogName)

return new FullConnectorSession(
this,
transactionAutoCommitContext,
identity.toConnectorIdentity(catalogName.getCatalogName()),
connectorProperties.getOrDefault(catalogName, ImmutableMap.of()),
catalogName,
Expand Down Expand Up @@ -493,7 +501,8 @@ public SessionRepresentation toSessionRepresentation()
unprocessedCatalogProperties,
identity.getCatalogRoles(),
preparedStatements,
protocolHeaders.getProtocolName());
protocolHeaders.getProtocolName(),
transactionAutoCommitContext);
}

@Override
Expand Down Expand Up @@ -824,7 +833,8 @@ public Session build()
catalogSessionProperties,
sessionPropertyManager,
preparedStatements,
protocolHeaders);
protocolHeaders,
Optional.empty());
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ public final class SessionRepresentation
private final Map<String, SelectedRole> catalogRoles;
private final Map<String, String> preparedStatements;
private final String protocolName;
private final Optional<Boolean> transactionAutoCommitContext;

@JsonCreator
public SessionRepresentation(
Expand Down Expand Up @@ -97,7 +98,8 @@ public SessionRepresentation(
@JsonProperty("unprocessedCatalogProperties") Map<String, Map<String, String>> unprocessedCatalogProperties,
@JsonProperty("catalogRoles") Map<String, SelectedRole> catalogRoles,
@JsonProperty("preparedStatements") Map<String, String> preparedStatements,
@JsonProperty("protocolName") String protocolName)
@JsonProperty("protocolName") String protocolName,
@JsonProperty("transactionAutoCommitContext") Optional<Boolean> transactionAutoCommitContext)
{
this.queryId = requireNonNull(queryId, "queryId is null");
this.transactionId = requireNonNull(transactionId, "transactionId is null");
Expand Down Expand Up @@ -136,6 +138,7 @@ public SessionRepresentation(
unprocessedCatalogPropertiesBuilder.put(entry.getKey(), ImmutableMap.copyOf(entry.getValue()));
}
this.unprocessedCatalogProperties = unprocessedCatalogPropertiesBuilder.build();
this.transactionAutoCommitContext = requireNonNull(transactionAutoCommitContext, "transactionAutoCommitContext is null");
}

@JsonProperty
Expand Down Expand Up @@ -343,6 +346,7 @@ public Session toSession(SessionPropertyManager sessionPropertyManager, Map<Stri
unprocessedCatalogProperties,
sessionPropertyManager,
preparedStatements,
createProtocolHeaders(protocolName));
createProtocolHeaders(protocolName),
transactionAutoCommitContext);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@
import io.trino.sql.tree.Statement;
import io.trino.testing.PageConsumerOperator.PageConsumerOutputFactory;
import io.trino.transaction.InMemoryTransactionManager;
import io.trino.transaction.TransactionId;
import io.trino.transaction.TransactionManager;
import io.trino.transaction.TransactionManagerConfig;
import io.trino.type.BlockTypeOperators;
Expand Down Expand Up @@ -424,9 +425,10 @@ private LocalQueryRunner(
connectorManager.createCatalog(GlobalSystemConnector.NAME, GlobalSystemConnector.NAME, ImmutableMap.of());

// rewrite session to use managed SessionPropertyMetadata
Optional<TransactionId> transactionId = withInitialTransaction ? Optional.of(transactionManager.beginTransaction(false)) : defaultSession.getTransactionId();
this.defaultSession = new Session(
defaultSession.getQueryId(),
withInitialTransaction ? Optional.of(transactionManager.beginTransaction(false)) : defaultSession.getTransactionId(),
transactionId,
defaultSession.isClientTransactionSupport(),
defaultSession.getIdentity(),
defaultSession.getSource(),
Expand All @@ -448,7 +450,8 @@ private LocalQueryRunner(
defaultSession.getUnprocessedCatalogProperties(),
metadata.getSessionPropertyManager(),
defaultSession.getPreparedStatements(),
defaultSession.getProtocolHeaders());
defaultSession.getProtocolHeaders(),
transactionId.map(tId -> transactionManager.getTransactionInfo(tId).isAutoCommitContext()));

dataDefinitionTask = ImmutableMap.<Class<? extends Statement>, DataDefinitionTask<?>>builder()
.put(CreateTable.class, new CreateTableTask(featuresConfig))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,12 @@ public String getQueryId()
return queryId;
}

@Override
public boolean isAutoCommitContext()
{
throw new UnsupportedOperationException();
}

@Override
public Optional<String> getSource()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@

public interface ConnectorSession
{
boolean isAutoCommitContext();

String getQueryId();

Optional<String> getSource();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,12 @@ public String getQueryId()
return "to_string";
}

@Override
public boolean isAutoCommitContext()
{
throw new UnsupportedOperationException();
}

@Override
public Optional<String> getSource()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,12 @@ public String getQueryId()
return "test_query_id";
}

@Override
public boolean isAutoCommitContext()
{
throw new UnsupportedOperationException();
}

@Override
public Optional<String> getSource()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1675,9 +1675,15 @@ public HiveInsertTableHandle beginInsert(ConnectorSession session, ConnectorTabl

WriteInfo writeInfo = locationService.getQueryWriteInfo(locationHandle);
if (getInsertExistingPartitionsBehavior(session) == InsertExistingPartitionsBehavior.OVERWRITE
&& isTransactionalTable(table.getParameters())
&& writeInfo.getWriteMode() == DIRECT_TO_TARGET_EXISTING_DIRECTORY) {
throw new TrinoException(NOT_SUPPORTED, "Overwriting existing partition in transactional tables doesn't support DIRECT_TO_TARGET_EXISTING_DIRECTORY write mode");
if (isTransactionalTable(table.getParameters())) {
throw new TrinoException(NOT_SUPPORTED, "Overwriting existing partition in transactional tables doesn't support DIRECT_TO_TARGET_EXISTING_DIRECTORY write mode");
}
// This check is required to prevent using partition overwrite operation during user managed transactions
// Partition overwrite operation is nonatomic thus can't and shouldn't be used in non autocommit context.
if (!session.isAutoCommitContext()) {
throw new TrinoException(NOT_SUPPORTED, "Overwriting existing partition in non auto commit context doesn't support DIRECT_TO_TARGET_EXISTING_DIRECTORY write mode");
aczajkowski marked this conversation as resolved.
Show resolved Hide resolved
}
}
metastore.declareIntentionToWrite(session, writeInfo.getWriteMode(), writeInfo.getWritePath(), tableName);
return result;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
package io.trino.plugin.hive;

import com.google.common.collect.ImmutableMap;
import io.trino.Session;
import io.trino.plugin.hive.containers.HiveMinioDataLake;
import io.trino.plugin.hive.s3.S3HiveQueryRunner;
import io.trino.testing.AbstractTestQueryFramework;
Expand All @@ -29,6 +30,7 @@
import static java.lang.String.format;
import static java.util.Objects.requireNonNull;
import static java.util.stream.Collectors.joining;
import static org.assertj.core.api.Assertions.assertThatThrownBy;

public abstract class BaseTestHiveInsertOverwrite
extends AbstractTestQueryFramework
Expand Down Expand Up @@ -75,6 +77,20 @@ public void setUp()
bucketName));
}

@Test
public void testInsertOverwriteInTransaction()
{
String testTable = getTestTableName();
computeActual(getCreateTableStatement(testTable, "partitioned_by=ARRAY['regionkey']"));
assertThatThrownBy(
() -> newTransaction()
.execute(getSession(), session -> {
getQueryRunner().execute(session, createInsertStatement(testTable));
}))
.hasMessage("Overwriting existing partition in non auto commit context doesn't support DIRECT_TO_TARGET_EXISTING_DIRECTORY write mode");
computeActual(format("DROP TABLE %s", testTable));
}

@Test
public void testInsertOverwriteNonPartitionedTable()
{
Expand Down Expand Up @@ -149,15 +165,26 @@ public void testInsertOverwritePartitionedAndBucketedExternalTable()
}

protected void assertInsertFailure(String testTable, String expectedMessageRegExp)
{
assertInsertFailure(getSession(), testTable, expectedMessageRegExp);
}

protected void assertInsertFailure(Session session, String testTable, String expectedMessageRegExp)
{
assertQueryFails(
format("INSERT INTO %s " +
"SELECT name, comment, nationkey, regionkey " +
"FROM tpch.tiny.nation",
testTable),
session,
createInsertStatement(testTable),
expectedMessageRegExp);
}

private String createInsertStatement(String testTable)
{
return format("INSERT INTO %s " +
"SELECT name, comment, nationkey, regionkey " +
"FROM tpch.tiny.nation",
testTable);
}

protected void assertOverwritePartition(String testTable)
{
computeActual(format(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5219,7 +5219,8 @@ public void testShowSession()
.build()),
getQueryRunner().getMetadata().getSessionPropertyManager(),
getSession().getPreparedStatements(),
getSession().getProtocolHeaders());
getSession().getProtocolHeaders(),
Optional.empty());
MaterializedResult result = computeActual(session, "SHOW SESSION");

ImmutableMap<String, MaterializedRow> properties = Maps.uniqueIndex(result.getMaterializedRows(), input -> {
Expand Down