
Commit

refresh cached tables; split multiple partition specs into separate manifest files
ahmedabu98 committed Oct 21, 2024
1 parent 3ee46c6 commit baba789
Showing 5 changed files with 163 additions and 64 deletions.
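
In short, the commit changes two things: (1) when a committed batch contains data files written under more than one partition spec, AppendFilesToTables now writes one manifest file per spec and appends all manifests in a single atomic commit, instead of appending data files directly; (2) RecordWriterManager now refreshes cached Table handles (and synchronizes table creation) so writers do not keep using a stale partition spec. Below is a minimal sketch of the per-spec manifest idea using the same Iceberg APIs as the diff; the class and method names are invented for illustration, and the data files are assumed to be already deserialized.

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.UUID;
import org.apache.iceberg.AppendFiles;
import org.apache.iceberg.DataFile;
import org.apache.iceberg.FileFormat;
import org.apache.iceberg.ManifestFiles;
import org.apache.iceberg.ManifestWriter;
import org.apache.iceberg.PartitionSpec;
import org.apache.iceberg.Table;
import org.apache.iceberg.io.OutputFile;

class PerSpecManifestAppend { // hypothetical helper mirroring appendManifestFiles below
  static void appendGroupedBySpec(Table table, Iterable<DataFile> dataFiles, String prefix)
      throws IOException {
    String uuid = UUID.randomUUID().toString();
    Map<Integer, ManifestWriter<DataFile>> writers = new HashMap<>();
    // Route each data file to a manifest writer keyed by its partition spec id.
    for (DataFile file : dataFiles) {
      PartitionSpec spec = table.specs().get(file.specId());
      ManifestWriter<DataFile> writer =
          writers.computeIfAbsent(
              file.specId(),
              id -> {
                String path =
                    FileFormat.AVRO.addExtension(
                        String.format(
                            "%s/metadata/%s-%s-%s.manifest", table.location(), prefix, uuid, id));
                OutputFile out = table.io().newOutputFile(path);
                return ManifestFiles.write(spec, out);
              });
      writer.add(file);
    }
    // Append every manifest in one operation so the whole batch commits atomically.
    AppendFiles update = table.newAppend();
    for (ManifestWriter<DataFile> writer : writers.values()) {
      writer.close();
      update.appendManifest(writer.toManifestFile());
    }
    update.commit();
  }
}

In the diff, the same logic lives in AppendFilesToTablesDoFn.appendManifestFiles, and batches written with a single partition spec still take the simpler AppendFiles.appendFile path.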
@@ -17,6 +17,9 @@
*/
package org.apache.beam.sdk.io.iceberg;

import java.io.IOException;
import java.util.Map;
import java.util.UUID;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.metrics.Counter;
@@ -29,14 +32,23 @@
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.transforms.WithKeys;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
import org.apache.beam.sdk.util.Preconditions;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Maps;
import org.apache.iceberg.AppendFiles;
import org.apache.iceberg.DataFile;
import org.apache.iceberg.FileFormat;
import org.apache.iceberg.ManifestFile;
import org.apache.iceberg.ManifestFiles;
import org.apache.iceberg.ManifestWriter;
import org.apache.iceberg.PartitionSpec;
import org.apache.iceberg.Snapshot;
import org.apache.iceberg.Table;
import org.apache.iceberg.catalog.Catalog;
import org.apache.iceberg.catalog.TableIdentifier;
import org.apache.iceberg.io.FileIO;
import org.apache.iceberg.io.OutputFile;
import org.checkerframework.checker.nullness.qual.MonotonicNonNull;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -45,9 +57,11 @@ class AppendFilesToTables
extends PTransform<PCollection<FileWriteResult>, PCollection<KV<String, SnapshotInfo>>> {
private static final Logger LOG = LoggerFactory.getLogger(AppendFilesToTables.class);
private final IcebergCatalogConfig catalogConfig;
private final String manifestFilePrefix;

AppendFilesToTables(IcebergCatalogConfig catalogConfig) {
AppendFilesToTables(IcebergCatalogConfig catalogConfig, String manifestFilePrefix) {
this.catalogConfig = catalogConfig;
this.manifestFilePrefix = manifestFilePrefix;
}

@Override
@@ -67,27 +81,27 @@ public String apply(FileWriteResult input) {
.apply("Group metadata updates by table", GroupByKey.create())
.apply(
"Append metadata updates to tables",
ParDo.of(new AppendFilesToTablesDoFn(catalogConfig)))
ParDo.of(new AppendFilesToTablesDoFn(catalogConfig, manifestFilePrefix)))
.setCoder(KvCoder.of(StringUtf8Coder.of(), SnapshotInfo.CODER));
}

private static class AppendFilesToTablesDoFn
extends DoFn<KV<String, Iterable<FileWriteResult>>, KV<String, SnapshotInfo>> {
private final Counter snapshotsCreated =
Metrics.counter(AppendFilesToTables.class, "snapshotsCreated");
private final Counter dataFilesCommitted =
Metrics.counter(AppendFilesToTables.class, "dataFilesCommitted");
private final Distribution committedDataFileByteSize =
Metrics.distribution(RecordWriter.class, "committedDataFileByteSize");
private final Distribution committedDataFileRecordCount =
Metrics.distribution(RecordWriter.class, "committedDataFileRecordCount");

private final IcebergCatalogConfig catalogConfig;
private final String manifestFilePrefix;

private transient @MonotonicNonNull Catalog catalog;

private AppendFilesToTablesDoFn(IcebergCatalogConfig catalogConfig) {
private AppendFilesToTablesDoFn(IcebergCatalogConfig catalogConfig, String manifestFilePrefix) {
this.catalogConfig = catalogConfig;
this.manifestFilePrefix = manifestFilePrefix;
}

private Catalog getCatalog() {
@@ -97,36 +111,103 @@ private Catalog getCatalog() {
return catalog;
}

private boolean containsMultiplePartitionSpecs(Iterable<FileWriteResult> fileWriteResults) {
int id = fileWriteResults.iterator().next().getSerializableDataFile().getPartitionSpecId();
for (FileWriteResult result : fileWriteResults) {
if (id != result.getSerializableDataFile().getPartitionSpecId()) {
return true;
}
}
return false;
}

@ProcessElement
public void processElement(
@Element KV<String, Iterable<FileWriteResult>> element,
OutputReceiver<KV<String, SnapshotInfo>> out,
BoundedWindow window) {
BoundedWindow window)
throws IOException {
String tableStringIdentifier = element.getKey();
Iterable<FileWriteResult> fileWriteResults = element.getValue();
if (!fileWriteResults.iterator().hasNext()) {
return;
}

Table table = getCatalog().loadTable(TableIdentifier.parse(element.getKey()));

// vast majority of the time, we will simply append data files.
// in the rare case we get a batch that contains multiple partition specs, we will group
// data into manifest files and append.
// note: either way, we must use a single commit operation for atomicity.
if (containsMultiplePartitionSpecs(fileWriteResults)) {
appendManifestFiles(table, fileWriteResults);
} else {
appendDataFiles(table, fileWriteResults);
}

Snapshot snapshot = table.currentSnapshot();
LOG.info("Created new snapshot for table '{}': {}", tableStringIdentifier, snapshot);
snapshotsCreated.inc();
out.outputWithTimestamp(
KV.of(element.getKey(), SnapshotInfo.fromSnapshot(snapshot)), window.maxTimestamp());
}

// This works only when all files are using the same partition spec.
private void appendDataFiles(Table table, Iterable<FileWriteResult> fileWriteResults) {
AppendFiles update = table.newAppend();
long numFiles = 0;
for (FileWriteResult result : fileWriteResults) {
DataFile dataFile = result.getDataFile(table.specs());
update.appendFile(dataFile);
committedDataFileByteSize.update(dataFile.fileSizeInBytes());
committedDataFileRecordCount.update(dataFile.recordCount());
numFiles++;
}
// this commit will create a ManifestFile. we don't need to manually create one.
update.commit();
dataFilesCommitted.inc(numFiles);
}

// When a user updates their table partition spec during runtime, we can end up with
// a batch of files where some are written with the old spec and some are written with the new
// spec.
// A table commit is limited to a single partition spec.
// To handle this, we create a manifest file for each partition spec, and group data files
// accordingly.
// Afterward, we append all manifests using a single commit operation.
private void appendManifestFiles(Table table, Iterable<FileWriteResult> fileWriteResults)
throws IOException {
String uuid = UUID.randomUUID().toString();
Map<Integer, PartitionSpec> specs = table.specs();
Map<Integer, ManifestWriter<DataFile>> manifestFileWriters = Maps.newHashMap();
// first add datafiles to the appropriate manifest file, according to its spec id
for (FileWriteResult result : fileWriteResults) {
DataFile dataFile = result.getDataFile(specs);
int specId = dataFile.specId();
PartitionSpec spec = Preconditions.checkStateNotNull(specs.get(specId));
ManifestWriter<DataFile> writer =
manifestFileWriters.computeIfAbsent(
specId, id -> createManifestWriter(table.location(), uuid, spec, table.io()));
writer.add(dataFile);
committedDataFileByteSize.update(dataFile.fileSizeInBytes());
committedDataFileRecordCount.update(dataFile.recordCount());
}

// append all manifest files and commit
AppendFiles update = table.newAppend();
for (ManifestWriter<DataFile> writer : manifestFileWriters.values()) {
writer.close();
ManifestFile manifestFile = writer.toManifestFile();
update.appendManifest(manifestFile);
}
update.commit();
}

Snapshot snapshot = table.currentSnapshot();
LOG.info("Created new snapshot for table '{}': {}", tableStringIdentifier, snapshot);
snapshotsCreated.inc();
out.outputWithTimestamp(
KV.of(element.getKey(), SnapshotInfo.fromSnapshot(snapshot)), window.maxTimestamp());
private ManifestWriter<DataFile> createManifestWriter(
String tableLocation, String uuid, PartitionSpec spec, FileIO io) {
String location =
FileFormat.AVRO.addExtension(
String.format(
"%s/metadata/%s-%s-%s.manifest",
tableLocation, manifestFilePrefix, uuid, spec.specId()));
OutputFile outputFile = io.newOutputFile(location);
return ManifestFiles.write(spec, outputFile);
}
}
}
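
As a side note on the naming used by createManifestWriter above: FileFormat.AVRO.addExtension appends an ".avro" suffix to the formatted path, so each manifest lands under the table's metadata directory. A tiny worked example, with the location, prefix, UUID and spec id made up for illustration:

import org.apache.iceberg.FileFormat;

class ManifestPathExample {
  public static void main(String[] args) {
    // All literal values below are assumptions, not taken from the diff.
    String path =
        FileFormat.AVRO.addExtension(
            String.format(
                "%s/metadata/%s-%s-%s.manifest",
                "gs://bucket/warehouse/db/tbl", "beam-write", "0f2e4c9a", 2));
    System.out.println(path);
    // prints: gs://bucket/warehouse/db/tbl/metadata/beam-write-0f2e4c9a-2.manifest.avro
  }
}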
@@ -25,7 +25,6 @@
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.util.Preconditions;
@@ -195,7 +194,9 @@ private RecordWriter createWriter(PartitionKey partitionKey) {

private final Map<WindowedValue<IcebergDestination>, List<SerializableDataFile>>
totalSerializableDataFiles = Maps.newHashMap();
private static final Cache<TableIdentifier, Table> TABLE_CACHE =

@VisibleForTesting
static final Cache<TableIdentifier, Table> TABLE_CACHE =
CacheBuilder.newBuilder().expireAfterAccess(10, TimeUnit.MINUTES).build();

private boolean isClosed = false;
@@ -221,22 +222,28 @@ private RecordWriter createWriter(PartitionKey partitionKey) {
private Table getOrCreateTable(TableIdentifier identifier, Schema dataSchema) {
@Nullable Table table = TABLE_CACHE.getIfPresent(identifier);
if (table == null) {
try {
table = catalog.loadTable(identifier);
} catch (NoSuchTableException e) {
synchronized (TABLE_CACHE) {
try {
org.apache.iceberg.Schema tableSchema =
IcebergUtils.beamSchemaToIcebergSchema(dataSchema);
// TODO(ahmedabu98): support creating a table with a specified partition spec
table = catalog.createTable(identifier, tableSchema);
LOG.info("Created Iceberg table '{}' with schema: {}", identifier, tableSchema);
} catch (AlreadyExistsException alreadyExistsException) {
// handle race condition where workers are concurrently creating the same table.
// if running into already exists exception, we perform one last load
table = catalog.loadTable(identifier);
} catch (NoSuchTableException e) {
try {
org.apache.iceberg.Schema tableSchema =
IcebergUtils.beamSchemaToIcebergSchema(dataSchema);
// TODO(ahmedabu98): support creating a table with a specified partition spec
table = catalog.createTable(identifier, tableSchema);
LOG.info("Created Iceberg table '{}' with schema: {}", identifier, tableSchema);
} catch (AlreadyExistsException alreadyExistsException) {
// handle race condition where workers are concurrently creating the same table.
// if running into already exists exception, we perform one last load
table = catalog.loadTable(identifier);
}
}
TABLE_CACHE.put(identifier, table);
}
TABLE_CACHE.put(identifier, table);
} else {
// If fetching from cache, refresh the table to avoid working with stale metadata
// (e.g. partition spec)
table.refresh();
}
return table;
}
@@ -254,15 +261,7 @@ public boolean write(WindowedValue<IcebergDestination> icebergDestination, Row r
icebergDestination,
destination -> {
TableIdentifier identifier = destination.getValue().getTableIdentifier();
Table table;
try {
table =
TABLE_CACHE.get(
identifier, () -> getOrCreateTable(identifier, row.getSchema()));
} catch (ExecutionException e) {
throw new RuntimeException(
"Error while fetching or creating table: " + identifier, e);
}
Table table = getOrCreateTable(identifier, row.getSchema());
return new DestinationState(destination.getValue(), table);
});

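
The cache side of the commit follows a simple pattern: check the cache, synchronize on it while loading or creating a missing table (so concurrent workers that lose the creation race fall back to a load), and refresh on a cache hit so stale metadata such as an outdated partition spec is not reused. A condensed sketch of that pattern follows; the class and method names are simplified stand-ins, plain Guava is used instead of Beam's vendored Guava, and the Iceberg schema is passed directly rather than converted from a Beam schema.

import java.util.concurrent.TimeUnit;
import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import org.apache.iceberg.Schema;
import org.apache.iceberg.Table;
import org.apache.iceberg.catalog.Catalog;
import org.apache.iceberg.catalog.TableIdentifier;
import org.apache.iceberg.exceptions.AlreadyExistsException;
import org.apache.iceberg.exceptions.NoSuchTableException;

class CachedTableLoader { // simplified stand-in for RecordWriterManager's cache logic
  private static final Cache<TableIdentifier, Table> TABLE_CACHE =
      CacheBuilder.newBuilder().expireAfterAccess(10, TimeUnit.MINUTES).build();

  static Table getOrCreateTable(Catalog catalog, TableIdentifier identifier, Schema schema) {
    Table table = TABLE_CACHE.getIfPresent(identifier);
    if (table == null) {
      // Serialize creation so concurrent workers don't race to create the same table.
      synchronized (TABLE_CACHE) {
        try {
          table = catalog.loadTable(identifier);
        } catch (NoSuchTableException e) {
          try {
            table = catalog.createTable(identifier, schema);
          } catch (AlreadyExistsException raced) {
            // Another worker won the race; load the table it created.
            table = catalog.loadTable(identifier);
          }
        }
        TABLE_CACHE.put(identifier, table);
      }
    } else {
      // Cache hit: refresh to avoid acting on stale metadata (e.g. an old partition spec).
      table.refresh();
    }
    return table;
  }
}

In the diff, the equivalent code converts the Beam row schema with IcebergUtils.beamSchemaToIcebergSchema before createTable, and the cache is exposed as @VisibleForTesting so the test below can assert on it.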
@@ -147,7 +147,7 @@ DataFile createDataFile(Map<Integer, PartitionSpec> partitionSpecs) {
checkStateNotNull(
partitionSpecs.get(getPartitionSpecId()),
"This DataFile was originally created with spec id '%s'. Could not find "
+ "this spec id in table's partition specs: %s.",
+ "this among table's partition specs: %s.",
getPartitionSpecId(),
partitionSpecs.keySet());

@@ -74,7 +74,7 @@ public IcebergWriteResult expand(PCollection<KV<String, Row>> input) {

// Commit files to tables
PCollection<KV<String, SnapshotInfo>> snapshots =
writtenFiles.apply(new AppendFilesToTables(catalogConfig));
writtenFiles.apply(new AppendFilesToTables(catalogConfig, filePrefix));

return new IcebergWriteResult(input.getPipeline(), snapshots);
}
@@ -19,12 +19,15 @@

import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.containsInAnyOrder;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.either;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertThrows;
import static org.junit.Assert.assertTrue;

import java.io.IOException;
import java.util.List;
import java.util.Map;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
@@ -310,10 +313,7 @@ public void testSerializableDataFileRoundTripEquality() throws IOException {

DataFile roundTripDataFile =
SerializableDataFile.from(datafile, partitionKey)
.createDataFile(
ImmutableMap.<Integer, PartitionSpec>builder()
.put(PARTITION_SPEC.specId(), PARTITION_SPEC)
.build());
.createDataFile(ImmutableMap.of(PARTITION_SPEC.specId(), PARTITION_SPEC));

checkDataFileEquality(datafile, roundTripDataFile);
}
@@ -329,34 +329,53 @@ public void testSerializableDataFileRoundTripEquality() throws IOException {
*/
@Test
public void testRecreateSerializableDataAfterUpdatingPartitionSpec() throws IOException {
PartitionKey partitionKey = new PartitionKey(PARTITION_SPEC, ICEBERG_SCHEMA);

Table table = catalog.loadTable(windowedDestination.getValue().getTableIdentifier());
Row row = Row.withSchema(BEAM_SCHEMA).addValues(1, "abcdef", true).build();
Row row2 = Row.withSchema(BEAM_SCHEMA).addValues(2, "abcxyz", true).build();
// same partition for both records (name_trunc=abc, bool=true)
partitionKey.partition(IcebergUtils.beamRowToIcebergRecord(ICEBERG_SCHEMA, row));

// write some rows
RecordWriter writer =
new RecordWriter(catalog, windowedDestination.getValue(), "test_file_name", partitionKey);
writer.write(IcebergUtils.beamRowToIcebergRecord(ICEBERG_SCHEMA, row));
writer.write(IcebergUtils.beamRowToIcebergRecord(ICEBERG_SCHEMA, row2));
RecordWriterManager writer =
new RecordWriterManager(catalog, "test_prefix", Long.MAX_VALUE, Integer.MAX_VALUE);
writer.write(windowedDestination, row);
writer.write(windowedDestination, row2);
writer.close();

// fetch data file and its serializable version
DataFile datafile = writer.getDataFile();
SerializableDataFile serializableDataFile = SerializableDataFile.from(datafile, partitionKey);

assertEquals(2L, datafile.recordCount());
assertEquals(serializableDataFile.getPartitionSpecId(), datafile.specId());
DataFile dataFile =
writer
.getSerializableDataFiles()
.get(windowedDestination)
.get(0)
.createDataFile(table.specs());

// check data file path contains the correct partition components
assertEquals(2L, dataFile.recordCount());
assertEquals(dataFile.specId(), PARTITION_SPEC.specId());
assertThat(dataFile.path().toString(), containsString("name_trunc=abc"));
assertThat(dataFile.path().toString(), containsString("bool=true"));

// table is cached
assertEquals(1, RecordWriterManager.TABLE_CACHE.size());

// update spec
Table table = catalog.loadTable(windowedDestination.getValue().getTableIdentifier());
table.updateSpec().addField("id").removeField("bool").commit();

Map<Integer, PartitionSpec> updatedSpecs = table.specs();
DataFile roundTripDataFile = serializableDataFile.createDataFile(updatedSpecs);

checkDataFileEquality(datafile, roundTripDataFile);
// write a second data file
// should refresh the table and use the new partition spec
RecordWriterManager writer2 =
new RecordWriterManager(catalog, "test_prefix_2", Long.MAX_VALUE, Integer.MAX_VALUE);
writer2.write(windowedDestination, row);
writer2.write(windowedDestination, row2);
writer2.close();

List<SerializableDataFile> serializableDataFiles =
writer2.getSerializableDataFiles().get(windowedDestination);
assertEquals(2, serializableDataFiles.size());
for (SerializableDataFile serializableDataFile : serializableDataFiles) {
assertEquals(table.spec().specId(), serializableDataFile.getPartitionSpecId());
dataFile = serializableDataFile.createDataFile(table.specs());
assertEquals(1L, dataFile.recordCount());
assertThat(dataFile.path().toString(), containsString("name_trunc=abc"));
assertThat(
dataFile.path().toString(), either(containsString("id=1")).or(containsString("id=2")));
}
}
}
