Skip to content

Commit

Permalink
Merge pull request apache#30971: apache#29902 finalize checkpoints af…
Browse files Browse the repository at this point in the history
…ter checkpoint
  • Loading branch information
je-ik authored Apr 15, 2024
2 parents 9fa45df + 2f38932 commit f41f364
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ public List<FlinkSourceSplit<T>> snapshotState(long checkpointId) {
String.format("Failed to get checkpoint for split %d", splitId), e);
}
});
addSplitsToUnfinishedForCheckpoint(checkpointId, splitsState);
return splitsState;
}

Expand Down Expand Up @@ -226,6 +227,16 @@ protected abstract FlinkSourceSplit<T> getReaderCheckpoint(
protected abstract Source.Reader<T> createReader(@Nonnull FlinkSourceSplit<T> sourceSplit)
throws IOException;

/**
* To be overridden in unbounded reader. Notify the reader of created splits that will be part of
* checkpoint. Will be processed during notifyCheckpointComplete to finalize the associated
* CheckpointMarks.
*/
protected void addSplitsToUnfinishedForCheckpoint(
long checkpointId, List<FlinkSourceSplit<T>> splits) {
// nop
}

// ----------------- protected helper methods for subclasses --------------------

protected final Optional<ReaderAndOutput> createAndTrackNextReader() throws IOException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@
import java.io.Serializable;
import org.apache.beam.runners.flink.translation.utils.SerdeUtils;
import org.apache.beam.sdk.io.Source;
import org.apache.beam.sdk.io.UnboundedSource;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.MoreObjects;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;
import org.apache.flink.api.connector.source.SourceSplit;
import org.apache.flink.core.io.SimpleVersionedSerializer;
import org.checkerframework.checker.nullness.qual.Nullable;
Expand All @@ -38,15 +40,25 @@ public class FlinkSourceSplit<T> implements SourceSplit, Serializable {
private final int splitIndex;
private final Source<T> beamSplitSource;
private final byte @Nullable [] splitState;
private final transient UnboundedSource.@Nullable CheckpointMark checkpointMark;

public FlinkSourceSplit(int splitIndex, Source<T> beamSplitSource) {
this(splitIndex, beamSplitSource, null);
this(splitIndex, beamSplitSource, null, null);
}

public FlinkSourceSplit(int splitIndex, Source<T> beamSplitSource, byte @Nullable [] splitState) {
public FlinkSourceSplit(
int splitIndex,
Source<T> beamSplitSource,
byte @Nullable [] splitState,
UnboundedSource.@Nullable CheckpointMark checkpointMark) {

this.splitIndex = splitIndex;
this.beamSplitSource = beamSplitSource;
this.splitState = splitState;
this.checkpointMark = checkpointMark;

// if we have state, we need checkpoint mark that we will finalize
Preconditions.checkArgument(splitState == null || checkpointMark != null);
}

public int splitIndex() {
Expand All @@ -66,12 +78,17 @@ public String splitId() {
return Integer.toString(splitIndex);
}

public UnboundedSource.@Nullable CheckpointMark getCheckpointMark() {
return checkpointMark;
}

@Override
public String toString() {
return MoreObjects.toStringHelper(this)
.add("splitIndex", splitIndex)
.add("beamSource", beamSplitSource)
.add("splitState.isNull", splitState == null)
.add("checkpointMark", checkpointMark)
.toString();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.NavigableMap;
import java.util.Optional;
import java.util.SortedMap;
import java.util.TreeMap;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.atomic.AtomicReference;
Expand Down Expand Up @@ -75,6 +78,8 @@ public class FlinkUnboundedSourceReader<T>
private final List<ReaderAndOutput> readers = new ArrayList<>();
private int currentReaderIndex = 0;
private volatile boolean shouldEmitWatermark;
private final NavigableMap<Long, List<FlinkSourceSplit<T>>> unfinishedCheckpoints =
new TreeMap<>();

public FlinkUnboundedSourceReader(
String stepName,
Expand All @@ -94,6 +99,28 @@ protected FlinkUnboundedSourceReader(
super(stepName, executor, context, pipelineOptions, timestampExtractor);
}

@Override
protected void addSplitsToUnfinishedForCheckpoint(
long checkpointId, List<FlinkSourceSplit<T>> flinkSourceSplits) {

unfinishedCheckpoints.put(checkpointId, flinkSourceSplits);
}

@Override
public void notifyCheckpointComplete(long checkpointId) throws Exception {
super.notifyCheckpointComplete(checkpointId);
SortedMap<Long, List<FlinkSourceSplit<T>>> headMap =
unfinishedCheckpoints.headMap(checkpointId + 1);
for (List<FlinkSourceSplit<T>> splits : headMap.values()) {
for (FlinkSourceSplit<T> s : splits) {
finalizeSourceSplit(s.getCheckpointMark());
}
}
for (long checkpoint : new ArrayList<>(headMap.keySet())) {
unfinishedCheckpoints.remove(checkpoint);
}
}

@Override
public void start() {
createPendingBytesGauge(context);
Expand Down Expand Up @@ -199,10 +226,16 @@ protected CompletableFuture<Void> isAvailableForAliveReaders() {
@Override
protected FlinkSourceSplit<T> getReaderCheckpoint(int splitId, ReaderAndOutput readerAndOutput) {
// The checkpoint for unbounded sources is fine granular.
byte[] checkpointState =
getAndEncodeCheckpointMark((UnboundedSource.UnboundedReader<T>) readerAndOutput.reader);
UnboundedSource.UnboundedReader<T> reader =
(UnboundedSource.UnboundedReader<T>) readerAndOutput.reader;
UnboundedSource.CheckpointMark checkpointMark = reader.getCheckpointMark();
@SuppressWarnings("unchecked")
Coder<UnboundedSource.CheckpointMark> coder =
(Coder<UnboundedSource.CheckpointMark>) reader.getCurrentSource().getCheckpointMarkCoder();
byte[] checkpointState = encodeCheckpointMark(coder, checkpointMark);

return new FlinkSourceSplit<>(
splitId, readerAndOutput.reader.getCurrentSource(), checkpointState);
splitId, readerAndOutput.reader.getCurrentSource(), checkpointState, checkpointMark);
}

@Override
Expand Down Expand Up @@ -308,13 +341,9 @@ private void createPendingBytesGauge(SourceReaderContext context) {
});
}

@SuppressWarnings("unchecked")
private <CheckpointMarkT extends UnboundedSource.CheckpointMark>
byte[] getAndEncodeCheckpointMark(UnboundedSource.UnboundedReader<T> reader) {
UnboundedSource<T, CheckpointMarkT> source =
(UnboundedSource<T, CheckpointMarkT>) reader.getCurrentSource();
CheckpointMarkT checkpointMark = (CheckpointMarkT) reader.getCheckpointMark();
Coder<CheckpointMarkT> coder = source.getCheckpointMarkCoder();
private <CheckpointMarkT extends UnboundedSource.CheckpointMark> byte[] encodeCheckpointMark(
Coder<CheckpointMarkT> coder, CheckpointMarkT checkpointMark) {

try (ByteArrayOutputStream baos = new ByteArrayOutputStream()) {
coder.encode(checkpointMark, baos);
return baos.toByteArray();
Expand All @@ -337,4 +366,11 @@ Source.Reader<T> createUnboundedSourceReader(
}
}
}

private void finalizeSourceSplit(UnboundedSource.@Nullable CheckpointMark mark)
throws IOException {
if (mark != null) {
mark.finalizeCheckpoint();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,13 @@ public void testSnapshotStateAndRestore() throws Exception {
reader = createReader()) {
pollAndValidate(reader, splits, validatingOutput, numSplits * numRecordsPerSplit / 2);
snapshot = reader.snapshotState(0L);
// use higher checkpoint number to verify that we finalize everything that was created
// up to that checkpoint
reader.notifyCheckpointComplete(1L);
}

assertEquals(numSplits, DummySource.numFinalizeCalled.size());

// Create another reader, add the snapshot splits back.
try (SourceReader<
WindowedValue<ValueWithRecordId<KV<Integer, Integer>>>,
Expand Down Expand Up @@ -298,6 +303,12 @@ public void testPendingBytesMetric() throws Exception {
/** A source whose advance() method only returns true occasionally. */
private static class DummySource extends TestCountingSource {

static List<Integer> numFinalizeCalled = new ArrayList<>();

static {
TestCountingSource.setFinalizeTracker(numFinalizeCalled);
}

public DummySource(int numMessagesPerShard) {
super(numMessagesPerShard);
}
Expand Down

0 comments on commit f41f364

Please sign in to comment.