Track windmill current active work budget. (#30048)
m-trieu authored Feb 8, 2024
1 parent 382c6dc commit 24efe76
Showing 4 changed files with 142 additions and 28 deletions.
@@ -1970,8 +1970,10 @@ public void handleHeartbeatResponses(List<Windmill.ComputationHeartbeatResponse>
failedWork
.computeIfAbsent(heartbeatResponse.getShardingKey(), key -> new ArrayList<>())
.add(
- new FailedTokens(
- heartbeatResponse.getWorkToken(), heartbeatResponse.getCacheToken()));
+ FailedTokens.newBuilder()
+ .setWorkToken(heartbeatResponse.getWorkToken())
+ .setCacheToken(heartbeatResponse.getCacheToken())
+ .build());
}
}
ComputationState state = computationMap.get(computationHeartbeatResponse.getComputationId());
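The map built above is what ActiveWorkState.failWorkForKey (changed below) consumes. A minimal sketch of that handoff, assuming an in-scope ActiveWorkState instance and hypothetical token and sharding-key values:

// Sketch only: 17L, 42L, and 7L are hypothetical sharding-key and token values.
Map<Long, List<FailedTokens>> failedWork = new HashMap<>();
failedWork
    .computeIfAbsent(17L, key -> new ArrayList<>())
    .add(FailedTokens.newBuilder().setWorkToken(42L).setCacheToken(7L).build());
activeWorkState.failWorkForKey(failedWork);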
@@ -19,6 +19,7 @@

import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList.toImmutableList;

import com.google.auto.value.AutoValue;
import java.io.PrintWriter;
import java.util.ArrayDeque;
import java.util.Deque;
@@ -28,6 +29,7 @@
import java.util.Map.Entry;
import java.util.Optional;
import java.util.Queue;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.BiConsumer;
import java.util.stream.Stream;
import javax.annotation.Nullable;
@@ -38,6 +40,7 @@
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.HeartbeatRequest;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItem;
import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateCache;
import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget;
import org.apache.beam.sdk.annotations.Internal;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;
@@ -70,11 +73,19 @@ public final class ActiveWorkState {
@GuardedBy("this")
private final WindmillStateCache.ForComputation computationStateCache;

/**
* Current budget that is being processed or queued on the user worker. Incremented when work is
* activated in {@link #activateWorkForKey(ShardedKey, Work)}, and decremented when work is
* completed in {@link #completeWorkAndGetNextWorkForKey(ShardedKey, long)}.
*/
private final AtomicReference<GetWorkBudget> activeGetWorkBudget;

private ActiveWorkState(
Map<ShardedKey, Deque<Work>> activeWork,
WindmillStateCache.ForComputation computationStateCache) {
this.activeWork = activeWork;
this.computationStateCache = computationStateCache;
this.activeGetWorkBudget = new AtomicReference<>(GetWorkBudget.noBudget());
}

static ActiveWorkState create(WindmillStateCache.ForComputation computationStateCache) {
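As a reference for the Javadoc above, a minimal sketch of the budget bookkeeping using only the GetWorkBudget calls that appear in this commit (noBudget, apply, subtract); the byte sizes are illustrative, not taken from the change:

import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget;

final class ActiveBudgetSketch {
  public static void main(String[] args) {
    GetWorkBudget budget = GetWorkBudget.noBudget(); // start empty, as the constructor above does
    budget = budget.apply(1, 512);    // activateWorkForKey: +1 item, +serialized size
    budget = budget.apply(1, 256);    // a second activation for the same or another key
    budget = budget.subtract(1, 512); // completeWorkAndGetNextWorkForKey: remove that item
    System.out.println(budget);       // 1 item, 256 bytes still active
  }
}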
@@ -88,6 +99,12 @@ static ActiveWorkState forTesting(
return new ActiveWorkState(activeWork, computationStateCache);
}

private static String elapsedString(Instant start, Instant end) {
Duration activeFor = new Duration(start, end);
// Duration's toString always starts with "PT"; remove that here.
return activeFor.toString().substring(2);
}

/**
* Activates {@link Work} for the {@link ShardedKey}. Outcome can be 1 of 3 {@link
* ActivateWorkResult}
@@ -103,12 +120,12 @@
*/
synchronized ActivateWorkResult activateWorkForKey(ShardedKey shardedKey, Work work) {
Deque<Work> workQueue = activeWork.getOrDefault(shardedKey, new ArrayDeque<>());

// This key does not have any work queued up on it. Create one, insert Work, and mark the work
// to be executed.
if (!activeWork.containsKey(shardedKey) || workQueue.isEmpty()) {
workQueue.addLast(work);
activeWork.put(shardedKey, workQueue);
incrementActiveWorkBudget(work);
return ActivateWorkResult.EXECUTE;
}

@@ -121,16 +138,27 @@ synchronized ActivateWorkResult activateWorkForKey(ShardedKey shardedKey, Work w

// Queue the work for later processing.
workQueue.addLast(work);
incrementActiveWorkBudget(work);
return ActivateWorkResult.QUEUED;
}
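Taken together with the EXECUTE branch above, activating two distinct work items for one key plays out as in this sketch (state, shardedKey, work1, and work2 are assumed to be set up as in the tests at the end of this commit):

ActivateWorkResult first = state.activateWorkForKey(shardedKey, work1);  // EXECUTE: key had no queued work
ActivateWorkResult second = state.activateWorkForKey(shardedKey, work2); // QUEUED: key is already active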

- public static final class FailedTokens {
- public long workToken;
- public long cacheToken;
+ @AutoValue
+ public abstract static class FailedTokens {
+ public static Builder newBuilder() {
+ return new AutoValue_ActiveWorkState_FailedTokens.Builder();
+ }
+
+ public abstract long workToken();
+
+ public abstract long cacheToken();
+
+ @AutoValue.Builder
+ public abstract static class Builder {
+ public abstract Builder setWorkToken(long value);
+
- public FailedTokens(long workToken, long cacheToken) {
- this.workToken = workToken;
- this.cacheToken = cacheToken;
+ public abstract Builder setCacheToken(long value);
+
+ public abstract FailedTokens build();
}
}
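Note on the AutoValue conversion above: the generated AutoValue_ActiveWorkState_FailedTokens class supplies the builder implementation plus value-based equals/hashCode/toString, so two FailedTokens built from the same tokens compare equal. A small sketch with hypothetical values:

FailedTokens a = FailedTokens.newBuilder().setWorkToken(42L).setCacheToken(7L).build();
FailedTokens b = FailedTokens.newBuilder().setWorkToken(42L).setCacheToken(7L).build();
boolean sameTokens = a.equals(b); // true: AutoValue-generated value equality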

@@ -148,17 +176,17 @@ synchronized void failWorkForKey(Map<Long, List<FailedTokens>> failedWork) {
for (FailedTokens failedToken : failedTokens) {
for (Work queuedWork : entry.getValue()) {
WorkItem workItem = queuedWork.getWorkItem();
- if (workItem.getWorkToken() == failedToken.workToken
- && workItem.getCacheToken() == failedToken.cacheToken) {
+ if (workItem.getWorkToken() == failedToken.workToken()
+ && workItem.getCacheToken() == failedToken.cacheToken()) {
LOG.debug(
"Failing work "
+ computationStateCache.getComputation()
+ " "
+ entry.getKey().shardingKey()
+ " "
+ failedToken.workToken
+ failedToken.workToken()
+ " "
+ failedToken.cacheToken
+ failedToken.cacheToken()
+ ". The work will be retried and is not lost.");
queuedWork.setFailed();
break;
@@ -168,6 +196,16 @@ synchronized void failWorkForKey(Map<Long, List<FailedTokens>> failedWork) {
}
}

private void incrementActiveWorkBudget(Work work) {
activeGetWorkBudget.updateAndGet(
getWorkBudget -> getWorkBudget.apply(1, work.getWorkItem().getSerializedSize()));
}

private void decrementActiveWorkBudget(Work work) {
activeGetWorkBudget.updateAndGet(
getWorkBudget -> getWorkBudget.subtract(1, work.getWorkItem().getSerializedSize()));
}

/**
* Removes the complete work from the {@link Queue<Work>}. The {@link Work} is marked as completed
* if its workToken matches the one that is passed in. Returns the next {@link Work} in the {@link
@@ -208,6 +246,7 @@ private synchronized void removeCompletedWorkFromQueue(

// We consumed the matching work item.
workQueue.remove();
decrementActiveWorkBudget(completedWork);
}

private synchronized Optional<Work> getNextWork(Queue<Work> workQueue, ShardedKey shardedKey) {
@@ -285,6 +324,15 @@ private static Stream<HeartbeatRequest> toHeartbeatRequestStream(
.build());
}

/**
* Returns the current aggregate {@link GetWorkBudget} that is active on the user worker. Active
* means that the work is received from Windmill, being processed or queued to be processed in
* {@link ActiveWorkState}, and not committed back to Windmill.
*/
GetWorkBudget currentActiveWorkBudget() {
return activeGetWorkBudget.get();
}
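A hypothetical consumer of this accessor (not part of this commit) might compare the active budget against a target budget to decide how much more work to request; a sketch under the assumption that GetWorkBudget exposes items() and bytes() accessors matching its builder:

GetWorkBudget target = GetWorkBudget.builder().setItems(100).setBytes(100_000_000L).build();
GetWorkBudget active = activeWorkState.currentActiveWorkBudget();
GetWorkBudget remaining = target.subtract(active.items(), active.bytes());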

synchronized void printActiveWork(PrintWriter writer, Instant now) {
writer.println(
"<table border=\"1\" "
@@ -328,12 +376,11 @@ synchronized void printActiveWork(PrintWriter writer, Instant now) {
writer.println(commitsPendingCount - MAX_PRINTABLE_COMMIT_PENDING_KEYS);
writer.println("<br>");
}
- }

- private static String elapsedString(Instant start, Instant end) {
- Duration activeFor = new Duration(start, end);
- // Duration's toString always starts with "PT"; remove that here.
- return activeFor.toString().substring(2);
+ writer.println("<br>");
+ writer.println("Current Active Work Budget: ");
+ writer.println(currentActiveWorkBudget());
+ writer.println("<br>");
}

enum ActivateWorkResult {
@@ -42,14 +42,12 @@

@NotThreadSafe
public class Work implements Runnable {

private final Windmill.WorkItem workItem;
private final Supplier<Instant> clock;
private final Instant startTime;
private final Map<Windmill.LatencyAttribution.State, Duration> totalDurationPerState;
private final Consumer<Work> processWorkFn;
private TimedState currentState;

private volatile boolean isFailed;

private Work(Windmill.WorkItem workItem, Supplier<Instant> clock, Consumer<Work> processWorkFn) {
@@ -32,12 +32,12 @@
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import javax.annotation.Nullable;
import org.apache.beam.runners.dataflow.worker.DataflowExecutionStateSampler;
import org.apache.beam.runners.dataflow.worker.streaming.ActiveWorkState.ActivateWorkResult;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.HeartbeatRequest;
import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateCache;
import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget;
import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
import org.joda.time.Instant;
@@ -50,9 +50,9 @@

@RunWith(JUnit4.class)
public class ActiveWorkStateTest {
@Rule public transient Timeout globalTimeout = Timeout.seconds(600);
private final WindmillStateCache.ForComputation computationStateCache =
mock(WindmillStateCache.ForComputation.class);
@Rule public transient Timeout globalTimeout = Timeout.seconds(600);
private Map<ShardedKey, Deque<Work>> readOnlyActiveWork;

private ActiveWorkState activeWorkState;
@@ -61,11 +61,7 @@ private static ShardedKey shardedKey(String str, long shardKey) {
return ShardedKey.create(ByteString.copyFromUtf8(str), shardKey);
}

- private static Work emptyWork() {
- return createWork(null);
- }
-
- private static Work createWork(@Nullable Windmill.WorkItem workItem) {
+ private static Work createWork(Windmill.WorkItem workItem) {
return Work.create(workItem, Instant::now, Collections.emptyList(), unused -> {});
}
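The tests below rely on a createWorkItem(long) helper whose definition falls outside the hunks shown here. A plausible minimal version, for reading purposes only (the field setters are assumed from the Windmill WorkItem proto used elsewhere in this commit):

private static Windmill.WorkItem createWorkItem(long workToken) {
  return Windmill.WorkItem.newBuilder()
      .setKey(ByteString.copyFromUtf8("key")) // key and sharding key values are arbitrary here
      .setShardingKey(1L)
      .setWorkToken(workToken)
      .build();
}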

@@ -92,7 +88,8 @@ public void setup() {
@Test
public void testActivateWorkForKey_EXECUTE_unknownKey() {
ActivateWorkResult activateWorkResult =
- activeWorkState.activateWorkForKey(shardedKey("someKey", 1L), emptyWork());
+ activeWorkState.activateWorkForKey(
+ shardedKey("someKey", 1L), createWork(createWorkItem(1L)));

assertEquals(ActivateWorkResult.EXECUTE, activateWorkResult);
}
@@ -214,6 +211,76 @@ public void testCompleteWorkAndGetNextWorkForKey_returnsWorkIfPresent() {
assertFalse(readOnlyActiveWork.containsKey(shardedKey));
}

@Test
public void testCurrentActiveWorkBudget_correctlyAggregatesActiveWorkBudget_oneShardKey() {
ShardedKey shardedKey = shardedKey("someKey", 1L);
Work work1 = createWork(createWorkItem(1L));
Work work2 = createWork(createWorkItem(2L));

activeWorkState.activateWorkForKey(shardedKey, work1);
activeWorkState.activateWorkForKey(shardedKey, work2);

GetWorkBudget expectedActiveBudget1 =
GetWorkBudget.builder()
.setItems(2)
.setBytes(
work1.getWorkItem().getSerializedSize() + work2.getWorkItem().getSerializedSize())
.build();

assertThat(activeWorkState.currentActiveWorkBudget()).isEqualTo(expectedActiveBudget1);

activeWorkState.completeWorkAndGetNextWorkForKey(
shardedKey, work1.getWorkItem().getWorkToken());

GetWorkBudget expectedActiveBudget2 =
GetWorkBudget.builder()
.setItems(1)
.setBytes(work1.getWorkItem().getSerializedSize())
.build();

assertThat(activeWorkState.currentActiveWorkBudget()).isEqualTo(expectedActiveBudget2);
}

@Test
public void testCurrentActiveWorkBudget_correctlyAggregatesActiveWorkBudget_whenWorkCompleted() {
ShardedKey shardedKey = shardedKey("someKey", 1L);
Work work1 = createWork(createWorkItem(1L));
Work work2 = createWork(createWorkItem(2L));

activeWorkState.activateWorkForKey(shardedKey, work1);
activeWorkState.activateWorkForKey(shardedKey, work2);
activeWorkState.completeWorkAndGetNextWorkForKey(
shardedKey, work1.getWorkItem().getWorkToken());

GetWorkBudget expectedActiveBudget =
GetWorkBudget.builder()
.setItems(1)
.setBytes(work1.getWorkItem().getSerializedSize())
.build();

assertThat(activeWorkState.currentActiveWorkBudget()).isEqualTo(expectedActiveBudget);
}

@Test
public void testCurrentActiveWorkBudget_correctlyAggregatesActiveWorkBudget_multipleShardKeys() {
ShardedKey shardedKey1 = shardedKey("someKey", 1L);
ShardedKey shardedKey2 = shardedKey("someKey", 2L);
Work work1 = createWork(createWorkItem(1L));
Work work2 = createWork(createWorkItem(2L));

activeWorkState.activateWorkForKey(shardedKey1, work1);
activeWorkState.activateWorkForKey(shardedKey2, work2);

GetWorkBudget expectedActiveBudget =
GetWorkBudget.builder()
.setItems(2)
.setBytes(
work1.getWorkItem().getSerializedSize() + work2.getWorkItem().getSerializedSize())
.build();

assertThat(activeWorkState.currentActiveWorkBudget()).isEqualTo(expectedActiveBudget);
}

@Test
public void testInvalidateStuckCommits() {
Map<ShardedKey, Long> invalidatedCommits = new HashMap<>();

