From 24efe7619d9a0e5fae70a6d69beabb2d23f362b1 Mon Sep 17 00:00:00 2001
From: martin trieu
Date: Thu, 8 Feb 2024 02:00:15 -0800
Subject: [PATCH] Track windmill current active work budget. (#30048)

---
 .../worker/StreamingDataflowWorker.java       |  6 +-
 .../worker/streaming/ActiveWorkState.java     | 79 ++++++++++++++----
 .../dataflow/worker/streaming/Work.java       |  2 -
 .../worker/streaming/ActiveWorkStateTest.java | 83 +++++++++++++++++--
 4 files changed, 142 insertions(+), 28 deletions(-)

diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java
index 463ab953faee..e8ca3a2834f9 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java
@@ -1970,8 +1970,10 @@ public void handleHeartbeatResponses(List<ComputationHeartbeatResponse> responses) {
           failedWork
               .computeIfAbsent(heartbeatResponse.getShardingKey(), key -> new ArrayList<>())
               .add(
-                  new FailedTokens(
-                      heartbeatResponse.getWorkToken(), heartbeatResponse.getCacheToken()));
+                  FailedTokens.newBuilder()
+                      .setWorkToken(heartbeatResponse.getWorkToken())
+                      .setCacheToken(heartbeatResponse.getCacheToken())
+                      .build());
         }
       }
       ComputationState state = computationMap.get(computationHeartbeatResponse.getComputationId());
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkState.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkState.java
index ff46356d9569..b4b469323932 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkState.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkState.java
@@ -19,6 +19,7 @@
 
 import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList.toImmutableList;
 
+import com.google.auto.value.AutoValue;
 import java.io.PrintWriter;
 import java.util.ArrayDeque;
 import java.util.Deque;
@@ -28,6 +29,7 @@
 import java.util.Map.Entry;
 import java.util.Optional;
 import java.util.Queue;
+import java.util.concurrent.atomic.AtomicReference;
 import java.util.function.BiConsumer;
 import java.util.stream.Stream;
 import javax.annotation.Nullable;
@@ -38,6 +40,7 @@
 import org.apache.beam.runners.dataflow.worker.windmill.Windmill.HeartbeatRequest;
 import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItem;
 import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateCache;
+import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget;
 import org.apache.beam.sdk.annotations.Internal;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;
@@ -70,11 +73,19 @@ public final class ActiveWorkState {
   @GuardedBy("this")
   private final WindmillStateCache.ForComputation computationStateCache;
 
+  /**
+   * Current budget that is being processed or queued on the user worker. Incremented when work is
+   * activated in {@link #activateWorkForKey(ShardedKey, Work)}, and decremented when work is
+   * completed in {@link #completeWorkAndGetNextWorkForKey(ShardedKey, long)}.
+   */
+  private final AtomicReference<GetWorkBudget> activeGetWorkBudget;
+
   private ActiveWorkState(
       Map<ShardedKey, Deque<Work>> activeWork,
       WindmillStateCache.ForComputation computationStateCache) {
     this.activeWork = activeWork;
     this.computationStateCache = computationStateCache;
+    this.activeGetWorkBudget = new AtomicReference<>(GetWorkBudget.noBudget());
   }
 
   static ActiveWorkState create(WindmillStateCache.ForComputation computationStateCache) {
@@ -88,6 +99,12 @@ static ActiveWorkState forTesting(
     return new ActiveWorkState(activeWork, computationStateCache);
   }
 
+  private static String elapsedString(Instant start, Instant end) {
+    Duration activeFor = new Duration(start, end);
+    // Duration's toString always starts with "PT"; remove that here.
+    return activeFor.toString().substring(2);
+  }
+
   /**
    * Activates {@link Work} for the {@link ShardedKey}. Outcome can be 1 of 3 {@link
    * ActivateWorkResult}
@@ -103,12 +120,12 @@ static ActiveWorkState forTesting(
    */
   synchronized ActivateWorkResult activateWorkForKey(ShardedKey shardedKey, Work work) {
     Deque<Work> workQueue = activeWork.getOrDefault(shardedKey, new ArrayDeque<>());
-
     // This key does not have any work queued up on it. Create one, insert Work, and mark the work
     // to be executed.
     if (!activeWork.containsKey(shardedKey) || workQueue.isEmpty()) {
       workQueue.addLast(work);
       activeWork.put(shardedKey, workQueue);
+      incrementActiveWorkBudget(work);
       return ActivateWorkResult.EXECUTE;
     }
 
@@ -121,16 +138,27 @@ synchronized ActivateWorkResult activateWorkForKey(ShardedKey shardedKey, Work w
 
     // Queue the work for later processing.
     workQueue.addLast(work);
+    incrementActiveWorkBudget(work);
     return ActivateWorkResult.QUEUED;
   }
 
-  public static final class FailedTokens {
-    public long workToken;
-    public long cacheToken;
+  @AutoValue
+  public abstract static class FailedTokens {
+    public static Builder newBuilder() {
+      return new AutoValue_ActiveWorkState_FailedTokens.Builder();
+    }
+
+    public abstract long workToken();
+
+    public abstract long cacheToken();
+
+    @AutoValue.Builder
+    public abstract static class Builder {
+      public abstract Builder setWorkToken(long value);
 
-    public FailedTokens(long workToken, long cacheToken) {
-      this.workToken = workToken;
-      this.cacheToken = cacheToken;
+      public abstract Builder setCacheToken(long value);
+
+      public abstract FailedTokens build();
+    }
   }
 
@@ -148,17 +176,17 @@ synchronized void failWorkForKey(Map<Long, List<FailedTokens>> failedWork) {
       for (FailedTokens failedToken : failedTokens) {
         for (Work queuedWork : entry.getValue()) {
           WorkItem workItem = queuedWork.getWorkItem();
-          if (workItem.getWorkToken() == failedToken.workToken
-              && workItem.getCacheToken() == failedToken.cacheToken) {
+          if (workItem.getWorkToken() == failedToken.workToken()
+              && workItem.getCacheToken() == failedToken.cacheToken()) {
             LOG.debug(
                 "Failing work "
                     + computationStateCache.getComputation()
                     + " "
                     + entry.getKey().shardingKey()
                     + " "
-                    + failedToken.workToken
+                    + failedToken.workToken()
                     + " "
-                    + failedToken.cacheToken
+                    + failedToken.cacheToken()
                     + ". The work will be retried and is not lost.");
             queuedWork.setFailed();
             break;
@@ -168,6 +196,16 @@ synchronized void failWorkForKey(Map<Long, List<FailedTokens>> failedWork) {
       }
     }
   }
 
+  private void incrementActiveWorkBudget(Work work) {
+    activeGetWorkBudget.updateAndGet(
+        getWorkBudget -> getWorkBudget.apply(1, work.getWorkItem().getSerializedSize()));
+  }
+
+  private void decrementActiveWorkBudget(Work work) {
+    activeGetWorkBudget.updateAndGet(
+        getWorkBudget -> getWorkBudget.subtract(1, work.getWorkItem().getSerializedSize()));
+  }
+
   /**
    * Removes the complete work from the {@link Queue}. The {@link Work} is marked as completed
    * if its workToken matches the one that is passed in. Returns the next {@link Work} in the {@link
@@ -208,6 +246,7 @@ private synchronized void removeCompletedWorkFromQueue(
 
     // We consumed the matching work item.
     workQueue.remove();
+    decrementActiveWorkBudget(completedWork);
   }
 
   private synchronized Optional<Work> getNextWork(Queue<Work> workQueue, ShardedKey shardedKey) {
@@ -285,6 +324,15 @@ private static Stream<HeartbeatRequest> toHeartbeatRequestStream(
             .build());
   }
 
+  /**
+   * Returns the current aggregate {@link GetWorkBudget} that is active on the user worker. Active
+   * means that the work is received from Windmill, being processed or queued to be processed in
+   * {@link ActiveWorkState}, and not committed back to Windmill.
+   */
+  GetWorkBudget currentActiveWorkBudget() {
+    return activeGetWorkBudget.get();
+  }
+
   synchronized void printActiveWork(PrintWriter writer, Instant now) {
     writer.println(
         "<table border=\"1\" "
@@ ... @@ synchronized void printActiveWork(PrintWriter writer, Instant now) {
     }
-  }
-
-  private static String elapsedString(Instant start, Instant end) {
-    Duration activeFor = new Duration(start, end);
-    // Duration's toString always starts with "PT"; remove that here.
-    return activeFor.toString().substring(2);
+    writer.println("<br>");
+    writer.println("Current Active Work Budget: ");
+    writer.println(currentActiveWorkBudget());
+    writer.println("<br>");
   }
 
   enum ActivateWorkResult {
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Work.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Work.java
index 8d4ba33a1abc..6c85c615af15 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Work.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Work.java
@@ -42,14 +42,12 @@
 @NotThreadSafe
 public class Work implements Runnable {
-
   private final Windmill.WorkItem workItem;
   private final Supplier<Instant> clock;
   private final Instant startTime;
   private final Map<Windmill.LatencyAttribution.State, Duration> totalDurationPerState;
   private final Consumer<Work> processWorkFn;
   private TimedState currentState;
-
   private volatile boolean isFailed;
 
   private Work(Windmill.WorkItem workItem, Supplier<Instant> clock, Consumer<Work> processWorkFn) {
diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkStateTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkStateTest.java
index b384bb03185d..82ff24c03bb8 100644
--- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkStateTest.java
+++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkStateTest.java
@@ -32,12 +32,12 @@
 import java.util.HashMap;
 import java.util.Map;
 import java.util.Optional;
-import javax.annotation.Nullable;
 import org.apache.beam.runners.dataflow.worker.DataflowExecutionStateSampler;
 import org.apache.beam.runners.dataflow.worker.streaming.ActiveWorkState.ActivateWorkResult;
 import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
 import org.apache.beam.runners.dataflow.worker.windmill.Windmill.HeartbeatRequest;
 import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateCache;
+import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget;
 import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
 import org.joda.time.Instant;
@@ -50,9 +50,9 @@
 @RunWith(JUnit4.class)
 public class ActiveWorkStateTest {
 
-  @Rule public transient Timeout globalTimeout = Timeout.seconds(600);
   private final WindmillStateCache.ForComputation computationStateCache =
       mock(WindmillStateCache.ForComputation.class);
+  @Rule public transient Timeout globalTimeout = Timeout.seconds(600);
   private Map<ShardedKey, Deque<Work>> readOnlyActiveWork;
 
   private ActiveWorkState activeWorkState;
@@ -61,11 +61,7 @@ private static ShardedKey shardedKey(String str, long shardKey) {
     return ShardedKey.create(ByteString.copyFromUtf8(str), shardKey);
   }
 
-  private static Work emptyWork() {
-    return createWork(null);
-  }
-
-  private static Work createWork(@Nullable Windmill.WorkItem workItem) {
+  private static Work createWork(Windmill.WorkItem workItem) {
     return Work.create(workItem, Instant::now, Collections.emptyList(), unused -> {});
   }
 
@@ -92,7 +88,8 @@ public void setup() {
   @Test
   public void testActivateWorkForKey_EXECUTE_unknownKey() {
     ActivateWorkResult activateWorkResult =
-        activeWorkState.activateWorkForKey(shardedKey("someKey", 1L), emptyWork());
+        activeWorkState.activateWorkForKey(
+            shardedKey("someKey", 1L), createWork(createWorkItem(1L)));
 
     assertEquals(ActivateWorkResult.EXECUTE, activateWorkResult);
   }
 
@@ -214,6 +211,76 @@ public void testCompleteWorkAndGetNextWorkForKey_returnsWorkIfPresent() {
     assertFalse(readOnlyActiveWork.containsKey(shardedKey));
   }
 
+  @Test
+  public void testCurrentActiveWorkBudget_correctlyAggregatesActiveWorkBudget_oneShardKey() {
+    ShardedKey shardedKey = shardedKey("someKey", 1L);
+    Work work1 = createWork(createWorkItem(1L));
+    Work work2 = createWork(createWorkItem(2L));
+
+    activeWorkState.activateWorkForKey(shardedKey, work1);
+    activeWorkState.activateWorkForKey(shardedKey, work2);
+
+    GetWorkBudget expectedActiveBudget1 =
+        GetWorkBudget.builder()
+            .setItems(2)
+            .setBytes(
+                work1.getWorkItem().getSerializedSize() + work2.getWorkItem().getSerializedSize())
+            .build();
+
+    assertThat(activeWorkState.currentActiveWorkBudget()).isEqualTo(expectedActiveBudget1);
+
+    activeWorkState.completeWorkAndGetNextWorkForKey(
+        shardedKey, work1.getWorkItem().getWorkToken());
+
+    GetWorkBudget expectedActiveBudget2 =
+        GetWorkBudget.builder()
+            .setItems(1)
+            .setBytes(work1.getWorkItem().getSerializedSize())
+            .build();
+
+    assertThat(activeWorkState.currentActiveWorkBudget()).isEqualTo(expectedActiveBudget2);
+  }
+
+  @Test
+  public void testCurrentActiveWorkBudget_correctlyAggregatesActiveWorkBudget_whenWorkCompleted() {
+    ShardedKey shardedKey = shardedKey("someKey", 1L);
+    Work work1 = createWork(createWorkItem(1L));
+    Work work2 = createWork(createWorkItem(2L));
+
+    activeWorkState.activateWorkForKey(shardedKey, work1);
+    activeWorkState.activateWorkForKey(shardedKey, work2);
+    activeWorkState.completeWorkAndGetNextWorkForKey(
+        shardedKey, work1.getWorkItem().getWorkToken());
+
+    GetWorkBudget expectedActiveBudget =
+        GetWorkBudget.builder()
+            .setItems(1)
+            .setBytes(work1.getWorkItem().getSerializedSize())
+            .build();
+
+    assertThat(activeWorkState.currentActiveWorkBudget()).isEqualTo(expectedActiveBudget);
+  }
+
+  @Test
+  public void testCurrentActiveWorkBudget_correctlyAggregatesActiveWorkBudget_multipleShardKeys() {
+    ShardedKey shardedKey1 = shardedKey("someKey", 1L);
+    ShardedKey shardedKey2 = shardedKey("someKey", 2L);
+    Work work1 = createWork(createWorkItem(1L));
+    Work work2 = createWork(createWorkItem(2L));
+
+    activeWorkState.activateWorkForKey(shardedKey1, work1);
+    activeWorkState.activateWorkForKey(shardedKey2, work2);
+
+    GetWorkBudget expectedActiveBudget =
+        GetWorkBudget.builder()
+            .setItems(2)
+            .setBytes(
+                work1.getWorkItem().getSerializedSize() + work2.getWorkItem().getSerializedSize())
+            .build();
+
+    assertThat(activeWorkState.currentActiveWorkBudget()).isEqualTo(expectedActiveBudget);
+  }
+
   @Test
   public void testInvalidateStuckCommits() {
     Map<ShardedKey, Instant> invalidatedCommits = new HashMap<>();
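
For readers who want the mechanics of the change without walking the whole diff: the patch keeps a single AtomicReference to an immutable (items, bytes) budget, grows it when work is activated in activateWorkForKey(), shrinks it when work is removed in completeWorkAndGetNextWorkForKey(), and exposes the snapshot through currentActiveWorkBudget(). The self-contained sketch below only illustrates that pattern; the Budget record, the method names, and the clamp-at-zero behavior are hypothetical stand-ins, not Beam's actual GetWorkBudget or ActiveWorkState API.

import java.util.concurrent.atomic.AtomicReference;

/** Illustrative sketch of the active-work budget accounting in this patch; not Beam code. */
final class ActiveBudgetSketch {

  /** Hypothetical immutable (items, bytes) pair standing in for GetWorkBudget. */
  record Budget(long items, long bytes) {
    Budget add(long i, long b) {
      return new Budget(items + i, bytes + b);
    }

    Budget subtract(long i, long b) {
      // Clamp at zero; an assumption for the sketch, not necessarily GetWorkBudget's behavior.
      return new Budget(Math.max(0, items - i), Math.max(0, bytes - b));
    }
  }

  // Aggregate of everything activated but not yet completed, readable without taking a lock.
  private final AtomicReference<Budget> active = new AtomicReference<>(new Budget(0, 0));

  // Mirrors incrementActiveWorkBudget(Work): called when a work item is queued or starts executing.
  void onActivate(long serializedSize) {
    active.updateAndGet(b -> b.add(1, serializedSize));
  }

  // Mirrors decrementActiveWorkBudget(Work): called when the matching work item completes.
  void onComplete(long serializedSize) {
    active.updateAndGet(b -> b.subtract(1, serializedSize));
  }

  Budget current() {
    return active.get();
  }

  public static void main(String[] args) {
    ActiveBudgetSketch budget = new ActiveBudgetSketch();
    budget.onActivate(128); // first work item activated for a key
    budget.onActivate(256); // second item queued behind the same key
    budget.onComplete(128); // first item completed and committed back
    System.out.println(budget.current()); // prints Budget[items=1, bytes=256]
  }
}

A likely reason for the AtomicReference rather than piggybacking on the class's synchronized methods: activation and completion already run under the ActiveWorkState monitor, but currentActiveWorkBudget() can then be polled from other threads (for example a status page or a budget-refresh loop) without contending for that lock.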