Add concurrent writes reconciliation for INSERT in Delta Lake
findinpath committed Aug 3, 2023
1 parent 43cee12 commit 661f4ba
Showing 2 changed files with 135 additions and 24 deletions.
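The diff below wraps the INSERT commit in a Failsafe retry loop so that a commit attempt which loses the race against a concurrent writer is re-evaluated and re-tried instead of failing outright. A minimal, self-contained sketch of that retry pattern follows; ConflictException, tryCommit, and the class name are placeholders rather than Trino code, while the delay, jitter, and retry values mirror the ones added to DeltaLakeMetadata below.

import java.time.Duration;

import dev.failsafe.Failsafe;
import dev.failsafe.RetryPolicy;

public class RetrySketch
{
    // Placeholder for TransactionConflictException: thrown when another writer
    // committed a new transaction log version between our read and our commit.
    static class ConflictException extends RuntimeException {}

    private static final RetryPolicy<Long> RETRY_POLICY = RetryPolicy.<Long>builder()
            .handleIf(throwable -> throwable instanceof ConflictException)
            .withDelay(Duration.ofMillis(200))
            .withJitter(Duration.ofMillis(100))
            .withMaxRetries(5)
            .build();

    public static void main(String[] args)
    {
        // Each attempt re-reads the current table version and re-writes the commit,
        // so a retry observes the state left behind by the concurrent writer.
        long committedVersion = Failsafe.with(RETRY_POLICY).get(RetrySketch::tryCommit);
        System.out.println("Committed transaction log version " + committedVersion);
    }

    private static long tryCommit()
    {
        // Stand-in for getInsertCommitVersion: validate the log, append the commit
        // entry, and return the version written, or throw ConflictException.
        return 1L;
    }
}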
@@ -14,6 +14,7 @@
package io.trino.plugin.deltalake;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.google.common.base.Throwables;
import com.google.common.base.VerifyException;
import com.google.common.collect.Comparators;
import com.google.common.collect.ImmutableList;
@@ -22,6 +23,8 @@
import com.google.common.collect.ImmutableTable;
import com.google.common.collect.Iterables;
import com.google.common.collect.Sets;
import dev.failsafe.Failsafe;
import dev.failsafe.RetryPolicy;
import io.airlift.json.JsonCodec;
import io.airlift.log.Logger;
import io.airlift.slice.Slice;
@@ -138,6 +141,7 @@
import java.io.IOException;
import java.net.URI;
import java.net.URISyntaxException;
import java.time.Duration;
import java.time.Instant;
import java.util.ArrayDeque;
import java.util.Collection;
@@ -157,6 +161,7 @@
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.LongStream;
import java.util.stream.Stream;

import static com.google.common.base.Preconditions.checkArgument;
@@ -187,6 +192,7 @@
import static io.trino.plugin.deltalake.DeltaLakeColumnType.PARTITION_KEY;
import static io.trino.plugin.deltalake.DeltaLakeColumnType.REGULAR;
import static io.trino.plugin.deltalake.DeltaLakeColumnType.SYNTHESIZED;
import static io.trino.plugin.deltalake.DeltaLakeErrorCode.DELTA_LAKE_BAD_DATA;
import static io.trino.plugin.deltalake.DeltaLakeErrorCode.DELTA_LAKE_BAD_WRITE;
import static io.trino.plugin.deltalake.DeltaLakeErrorCode.DELTA_LAKE_FILESYSTEM_ERROR;
import static io.trino.plugin.deltalake.DeltaLakeErrorCode.DELTA_LAKE_INVALID_SCHEMA;
@@ -241,6 +247,7 @@
import static io.trino.plugin.deltalake.transactionlog.MetadataEntry.configurationForNewTable;
import static io.trino.plugin.deltalake.transactionlog.TransactionLogParser.getMandatoryCurrentVersion;
import static io.trino.plugin.deltalake.transactionlog.TransactionLogUtil.getTransactionLogDir;
import static io.trino.plugin.deltalake.transactionlog.checkpoint.TransactionLogTail.getEntriesFromJson;
import static io.trino.plugin.hive.HiveMetadata.PRESTO_QUERY_ID_NAME;
import static io.trino.plugin.hive.TableType.EXTERNAL_TABLE;
import static io.trino.plugin.hive.TableType.MANAGED_TABLE;
@@ -258,6 +265,7 @@
import static io.trino.spi.StandardErrorCode.INVALID_SCHEMA_PROPERTY;
import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED;
import static io.trino.spi.StandardErrorCode.QUERY_REJECTED;
import static io.trino.spi.StandardErrorCode.TRANSACTION_CONFLICT;
import static io.trino.spi.connector.RetryMode.NO_RETRIES;
import static io.trino.spi.connector.RowChangeParadigm.DELETE_ROW_AND_INSERT_ROW;
import static io.trino.spi.connector.SchemaTableName.schemaTableName;
@@ -324,6 +332,12 @@ public class DeltaLakeMetadata
private static final int CDF_SUPPORTED_WRITER_VERSION = 4;
private static final int COLUMN_MAPPING_MODE_SUPPORTED_READER_VERSION = 2;
private static final int COLUMN_MAPPING_MODE_SUPPORTED_WRITER_VERSION = 5;
private static final RetryPolicy<Object> TRANSACTION_CONFLICT_RETRY_POLICY = RetryPolicy.builder()
.handleIf(throwable -> Throwables.getRootCause(throwable) instanceof TransactionConflictException)
.withDelay(Duration.ofMillis(200))
.withJitter(Duration.ofMillis(100))
.withMaxRetries(5)
.build();

// Matches the dummy column Databricks stores in the metastore
private static final List<Column> DUMMY_DATA_COLUMNS = ImmutableList.of(
@@ -1701,30 +1715,10 @@ public Optional<ConnectorOutputMetadata> finishInsert(

boolean writeCommitted = false;
try {
TransactionLogWriter transactionLogWriter = transactionLogWriterFactory.newWriter(session, handle.getLocation());

long createdTime = Instant.now().toEpochMilli();

TrinoFileSystem fileSystem = fileSystemFactory.create(session);
long commitVersion = getMandatoryCurrentVersion(fileSystem, handle.getLocation()) + 1;
if (commitVersion != handle.getReadVersion() + 1) {
throw new TransactionConflictException(format("Conflicting concurrent writes found. Expected transaction log version: %s, actual version: %s",
handle.getReadVersion(),
commitVersion - 1));
}
Optional<Long> checkpointInterval = handle.getMetadataEntry().getCheckpointInterval();
// it is not obvious why we need to persist this readVersion
transactionLogWriter.appendCommitInfoEntry(getCommitInfoEntry(session, commitVersion, createdTime, INSERT_OPERATION, handle.getReadVersion()));

ColumnMappingMode columnMappingMode = getColumnMappingMode(handle.getMetadataEntry());
List<String> partitionColumns = getPartitionColumns(
handle.getMetadataEntry().getOriginalPartitionColumns(),
handle.getInputColumns(),
columnMappingMode);
appendAddFileEntries(transactionLogWriter, dataFileInfos, partitionColumns, true);

transactionLogWriter.flush();
long commitVersion = Failsafe.with(TRANSACTION_CONFLICT_RETRY_POLICY)
.get(() -> getInsertCommitVersion(session, handle, dataFileInfos));
writeCommitted = true;
Optional<Long> checkpointInterval = handle.getMetadataEntry().getCheckpointInterval();
writeCheckpointIfNeeded(session, handle.getTableName(), handle.getLocation(), checkpointInterval, commitVersion);

if (isCollectExtendedStatisticsColumnStatisticsOnWrite(session) && !computedStatistics.isEmpty() && !dataFileInfos.isEmpty()) {
@@ -1755,6 +1749,67 @@ public Optional<ConnectorOutputMetadata> finishInsert(
return Optional.empty();
}

private long getInsertCommitVersion(ConnectorSession session, DeltaLakeInsertTableHandle handle, List<DataFileInfo> dataFileInfos)
throws IOException
{
long createdTime = Instant.now().toEpochMilli();

TrinoFileSystem fileSystem = fileSystemFactory.create(session);

String transactionLogDirectory = getTransactionLogDir(handle.getLocation());
long currentVersion = getMandatoryCurrentVersion(fileSystem, handle.getLocation());
if (currentVersion < handle.getReadVersion()) {
throw new TrinoException(TRANSACTION_CONFLICT, format("Conflicting concurrent writes found. Expected transaction log version: %s, actual version: %s",
handle.getReadVersion(),
currentVersion));
}
else if (currentVersion > handle.getReadVersion()) {
// Ensure there are no structural changes on the table if concurrent writes finished in the meantime
List<DeltaLakeTransactionLogEntry> transactionLogEntries = LongStream.rangeClosed(handle.getReadVersion() + 1, currentVersion)
.boxed()
.flatMap(version -> {
try {
return getEntriesFromJson(version, transactionLogDirectory, fileSystem)
.orElseThrow(() -> new TrinoException(DELTA_LAKE_BAD_DATA, "Delta Lake log entries are missing for version " + version))
.stream();
}
catch (IOException e) {
throw new TrinoException(DELTA_LAKE_FILESYSTEM_ERROR, "Failed to access table metadata", e);
}
})
.collect(toImmutableList());
Optional<MetadataEntry> currentMetadataEntry = transactionLogEntries.stream()
.map(DeltaLakeTransactionLogEntry::getMetaData)
.filter(Objects::nonNull)
.findFirst();
if (currentMetadataEntry.isPresent()) {
throw new TrinoException(TRANSACTION_CONFLICT, format("Conflicting concurrent writes found. Metadata changed since the version: %s", handle.getReadVersion()));
}
Optional<ProtocolEntry> currentProtocolEntry = transactionLogEntries.stream()
.map(DeltaLakeTransactionLogEntry::getProtocol)
.filter(Objects::nonNull)
.findFirst();
if (currentProtocolEntry.isPresent()) {
throw new TrinoException(TRANSACTION_CONFLICT, format("Conflicting concurrent writes found. Protocol changed since the version: %s", handle.getReadVersion()));
}
}
long commitVersion = currentVersion + 1;
// it is not obvious why we need to persist this readVersion
TransactionLogWriter transactionLogWriter = transactionLogWriterFactory.newWriter(session, handle.getLocation());
transactionLogWriter.appendCommitInfoEntry(getCommitInfoEntry(session, commitVersion, createdTime, INSERT_OPERATION, currentVersion));

ColumnMappingMode columnMappingMode = getColumnMappingMode(handle.getMetadataEntry());
List<String> partitionColumns = getPartitionColumns(
handle.getMetadataEntry().getOriginalPartitionColumns(),
handle.getInputColumns(),
columnMappingMode);
appendAddFileEntries(transactionLogWriter, dataFileInfos, partitionColumns, true);

transactionLogWriter.flush();

return commitVersion;
}

private static List<String> getPartitionColumns(List<String> originalPartitionColumns, List<DeltaLakeColumnHandle> dataColumns, ColumnMappingMode columnMappingMode)
{
return switch (columnMappingMode) {
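Condensed, the reconciliation rule implemented by getInsertCommitVersion above is: an INSERT prepared against readVersion may still be committed on top of a newer currentVersion as long as none of the intervening transaction log versions changed the table's metadata or protocol, because the insert only adds data files. A hypothetical stand-alone sketch of that check, where LogEntry stands in for DeltaLakeTransactionLogEntry:

import java.util.List;
import java.util.Optional;

record LogEntry(Optional<Object> metadata, Optional<Object> protocol) {}

class ReconciliationSketch
{
    static void checkNoStructuralChanges(long readVersion, long currentVersion, List<LogEntry> entriesSinceRead)
    {
        if (currentVersion < readVersion) {
            // The transaction log can never go backwards; this signals a corrupted or rewritten table.
            throw new IllegalStateException("Expected at least version " + readVersion + ", found " + currentVersion);
        }
        if (entriesSinceRead.stream().anyMatch(entry -> entry.metadata().isPresent())) {
            throw new IllegalStateException("Metadata changed since version " + readVersion);
        }
        if (entriesSinceRead.stream().anyMatch(entry -> entry.protocol().isPresent())) {
            throw new IllegalStateException("Protocol changed since version " + readVersion);
        }
        // No structural change: the pending insert can be re-attempted at currentVersion + 1.
    }
}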
@@ -17,6 +17,7 @@
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import io.airlift.concurrent.MoreFutures;
import io.airlift.units.DataSize;
import io.airlift.units.Duration;
import io.trino.Session;
@@ -32,6 +33,7 @@
import io.trino.spi.QueryId;
import io.trino.sql.planner.OptimizerConfig.JoinDistributionType;
import io.trino.testing.BaseConnectorSmokeTest;
import io.trino.testing.DataProviders;
import io.trino.testing.DistributedQueryRunner;
import io.trino.testing.MaterializedResult;
import io.trino.testing.MaterializedResultWithQueryId;
@@ -49,6 +51,9 @@
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.Callable;
import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.ExecutorService;
import java.util.function.BiConsumer;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
@@ -80,6 +85,7 @@
import static io.trino.tpch.TpchTable.ORDERS;
import static java.lang.String.format;
import static java.util.Comparator.comparing;
import static java.util.concurrent.Executors.newFixedThreadPool;
import static java.util.concurrent.TimeUnit.MILLISECONDS;
import static java.util.concurrent.TimeUnit.SECONDS;
import static org.assertj.core.api.Assertions.assertThat;
@@ -1324,7 +1330,8 @@ public Object[][] testCheckpointWriteStatsAsStructDataProvider()
{"varchar", "'test'", "'ŻŻŻŻŻŻŻŻŻŻ'", "0.0", "null", "null"},
{"varbinary", "X'65683F'", "X'ffffffffffffffffffff'", "0.0", "null", "null"},
{"date", "date '2021-02-03'", "date '9999-12-31'", "0.0", "'2021-02-03'", "'9999-12-31'"},
{"timestamp(3) with time zone", "timestamp '2001-08-22 03:04:05.321 -08:00'", "timestamp '9999-12-31 23:59:59.999 +12:00'", "0.0", "'2001-08-22 11:04:05.321 UTC'", "'9999-12-31 11:59:59.999 UTC'"},
{"timestamp(3) with time zone", "timestamp '2001-08-22 03:04:05.321 -08:00'", "timestamp '9999-12-31 23:59:59.999 +12:00'", "0.0", "'2001-08-22 11:04:05.321 UTC'",
"'9999-12-31 11:59:59.999 UTC'"},
{"array(int)", "array[1]", "array[2147483647]", "null", "null", "null"},
{"map(varchar,int)", "map(array['foo', 'bar'], array[1, 2])", "map(array['foo', 'bar'], array[-2147483648, 2147483647])", "null", "null", "null"},
{"row(x bigint)", "cast(row(1) as row(x bigint))", "cast(row(9223372036854775807) as row(x bigint))", "null", "null", "null"},
@@ -2151,6 +2158,55 @@ public void testPartitionFilterIncluded()
}
}

@Test(dataProviderClass = DataProviders.class, dataProvider = "trueFalse")
public void testConcurrentModificationsReconciliation(boolean partitioned)
throws Exception
{
int threads = 3;

CyclicBarrier barrier = new CyclicBarrier(threads);
ExecutorService executor = newFixedThreadPool(threads);
String tableName = "test_concurrent_inserts_table_" + randomNameSuffix();

assertUpdate("CREATE TABLE " + tableName + " (a INT, part INT) " +
(partitioned ? " WITH (partitioned_by = ARRAY['part'])" : ""));

try {
// insert data concurrently
executor.invokeAll(ImmutableList.<Callable<Void>>builder()
.add(() -> {
barrier.await(20, SECONDS);
getQueryRunner().execute("INSERT INTO " + tableName + " VALUES (1, 10)");
return null;
})
.add(() -> {
barrier.await(20, SECONDS);
getQueryRunner().execute("INSERT INTO " + tableName + " VALUES (11, 20)");
return null;
})
.add(() -> {
barrier.await(20, SECONDS);
getQueryRunner().execute("INSERT INTO " + tableName + " VALUES (21, 30)");
return null;
})
.build())
.forEach(MoreFutures::getDone);

assertThat(query("SELECT SUM(a) FROM " + tableName)).matches("VALUES BIGINT '33'");
assertQuery("SELECT version, operation, read_version FROM \"" + tableName + "$history\"",
"""
VALUES
(0, 'CREATE TABLE', 0),
(1, 'WRITE', 0),
(2, 'WRITE', 1),
(3, 'WRITE', 2)
""");
}
finally {
assertUpdate("DROP TABLE " + tableName);
}
}

private Set<String> getActiveFiles(String tableName)
{
return getActiveFiles(tableName, getQueryRunner().getDefaultSession());
