Skip to content

Commit

Permalink
address PR comments
Browse files Browse the repository at this point in the history
  • Loading branch information
m-trieu committed Oct 29, 2024
1 parent 93c5b31 commit 77969d0
Show file tree
Hide file tree
Showing 5 changed files with 369 additions and 141 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,24 +25,15 @@
import java.util.concurrent.Executors;
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;
import java.util.function.Supplier;
import javax.annotation.Nullable;
import javax.annotation.concurrent.GuardedBy;
import javax.annotation.concurrent.ThreadSafe;
import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.observers.StreamObserverCancelledException;
import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.observers.StreamObserverFactory;
import org.apache.beam.sdk.util.BackOff;
import org.apache.beam.vendor.grpc.v1p60p1.com.google.api.client.util.Sleeper;
import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.Status;
import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.ThreadFactoryBuilder;
import org.joda.time.DateTime;
import org.joda.time.Instant;
import org.slf4j.Logger;

Expand Down Expand Up @@ -74,9 +65,9 @@ public abstract class AbstractWindmillStream<RequestT, ResponseT> implements Win
// Default gRPC streams to 2MB chunks, which has shown to be a large enough chunk size to reduce
// per-chunk overhead, and small enough that we can still perform granular flow-control.
protected static final int RPC_STREAM_CHUNK_SIZE = 2 << 20;
// Indicates that the logical stream has been half-closed and is waiting for clean server
// shutdown.
private static final Status OK_STATUS = Status.fromCode(Status.Code.OK);

protected final AtomicBoolean clientClosed;
protected final Sleeper sleeper;

/**
Expand All @@ -87,30 +78,25 @@ public abstract class AbstractWindmillStream<RequestT, ResponseT> implements Win
*/
protected final Object shutdownLock = new Object();

private final AtomicLong lastSendTimeMs;
private final Logger logger;
private final ExecutorService executor;
private final BackOff backoff;
private final AtomicLong startTimeMs;
private final AtomicLong lastResponseTimeMs;
private final AtomicInteger restartCount;
private final AtomicInteger errorCount;
private final AtomicReference<String> lastRestartReason;
private final AtomicReference<DateTime> lastRestartTime;
private final AtomicLong sleepUntil;
private final CountDownLatch finishLatch;
private final Set<AbstractWindmillStream<?, ?>> streamRegistry;
private final int logEveryNStreamFailures;
private final String backendWorkerToken;
private final ResettableRequestObserver<RequestT> requestObserver;
private final AtomicReference<DateTime> shutdownTime;
private final ResettableStreamObserver<RequestT> requestObserver;
private final StreamDebugMetrics debugMetrics;
protected volatile boolean clientClosed;

/**
* Indicates if the current {@link ResettableRequestObserver} was closed by calling {@link
* #halfClose()}.
* Indicates if the current {@link ResettableStreamObserver} was closed by calling {@link
* #halfClose()}. Separate from {@link #clientClosed} as this is specific to the requestObserver
* and is initially false on retry.
*/
private final AtomicBoolean streamClosed;
@GuardedBy("this")
private boolean streamClosed;

private final Logger logger;
private volatile boolean isShutdown;
private volatile boolean started;

Expand All @@ -133,28 +119,20 @@ protected AbstractWindmillStream(
this.backoff = backoff;
this.streamRegistry = streamRegistry;
this.logEveryNStreamFailures = logEveryNStreamFailures;
this.clientClosed = new AtomicBoolean();
this.clientClosed = false;
this.isShutdown = false;
this.started = false;
this.streamClosed = new AtomicBoolean(false);
this.startTimeMs = new AtomicLong();
this.lastSendTimeMs = new AtomicLong();
this.lastResponseTimeMs = new AtomicLong();
this.restartCount = new AtomicInteger();
this.errorCount = new AtomicInteger();
this.lastRestartReason = new AtomicReference<>();
this.lastRestartTime = new AtomicReference<>();
this.sleepUntil = new AtomicLong();
this.streamClosed = false;
this.finishLatch = new CountDownLatch(1);
this.requestObserver =
new ResettableRequestObserver<>(
new ResettableStreamObserver<>(
() ->
streamObserverFactory.from(
clientFactory,
new AbstractWindmillStream<RequestT, ResponseT>.ResponseObserver()));
this.sleeper = Sleeper.DEFAULT;
this.logger = logger;
this.shutdownTime = new AtomicReference<>();
this.debugMetrics = new StreamDebugMetrics();
}

private static String createThreadName(String streamType, String backendWorkerToken) {
Expand All @@ -163,10 +141,6 @@ private static String createThreadName(String streamType, String backendWorkerTo
: String.format("%s-WindmillStream-thread", streamType);
}

private static long debugDuration(long nowMs, long startMs) {
return startMs <= 0 ? -1 : Math.max(0, nowMs - startMs);
}

/** Called on each response from the server. */
protected abstract void onResponse(ResponseT response);

Expand Down Expand Up @@ -195,13 +169,13 @@ protected final void send(RequestT request) {
return;
}

if (streamClosed.get()) {
if (streamClosed) {
// TODO(m-trieu): throw a more specific exception here (i.e StreamClosedException)
throw new IllegalStateException("Send called on a client closed stream.");
}

try {
lastSendTimeMs.set(Instant.now().getMillis());
debugMetrics.recordSend();
requestObserver.onNext(request);
} catch (StreamObserverCancelledException e) {
if (isShutdown) {
Expand Down Expand Up @@ -239,21 +213,22 @@ private void startStream() {
if (isShutdown) {
break;
}
startTimeMs.set(Instant.now().getMillis());
lastResponseTimeMs.set(0);
streamClosed.set(false);
debugMetrics.recordStart();
streamClosed = false;
requestObserver.reset();
onNewStream();
if (clientClosed.get()) {
if (clientClosed) {
halfClose();
}
return;
}
} catch (WindmillStreamShutdownException e) {
logger.debug("Stream was shutdown waiting to start.", e);
} catch (Exception e) {
logger.error("Failed to create new stream, retrying: ", e);
try {
long sleep = backoff.nextBackOffMillis();
sleepUntil.set(Instant.now().getMillis() + sleep);
debugMetrics.recordSleep(sleep);
sleeper.sleep(sleep);
} catch (InterruptedException ie) {
Thread.currentThread().interrupt();
Expand Down Expand Up @@ -285,7 +260,7 @@ protected final void executeSafely(Runnable runnable) {
}

public final void maybeSendHealthCheck(Instant lastSendThreshold) {
if (!clientClosed.get() && lastSendTimeMs.get() < lastSendThreshold.getMillis()) {
if (!clientClosed && debugMetrics.lastSendTimeMs() < lastSendThreshold.getMillis()) {
try {
sendHealthCheck();
} catch (RuntimeException e) {
Expand All @@ -303,28 +278,19 @@ public final void maybeSendHealthCheck(Instant lastSendThreshold) {
*/
public final void appendSummaryHtml(PrintWriter writer) {
appendSpecificHtml(writer);
if (restartCount.get() > 0) {
writer.format(
", %d restarts, last restart reason [ %s ] at [%s], %d errors",
restartCount.get(), lastRestartReason.get(), lastRestartTime.get(), errorCount.get());
}
if (clientClosed.get()) {
debugMetrics.printRestartsHtml(writer);
if (clientClosed) {
writer.write(", client closed");
}
long nowMs = Instant.now().getMillis();
long sleepLeft = sleepUntil.get() - nowMs;
long sleepLeft = debugMetrics.sleepLeft();
if (sleepLeft > 0) {
writer.format(", %dms backoff remaining", sleepLeft);
}
debugMetrics.printSummaryHtml(writer, nowMs);
writer.format(
", current stream is %dms old, last send %dms, last response %dms, closed: %s, "
+ "isShutdown: %s, shutdown time: %s",
debugDuration(nowMs, startTimeMs.get()),
debugDuration(nowMs, lastSendTimeMs.get()),
debugDuration(nowMs, lastResponseTimeMs.get()),
streamClosed.get(),
isShutdown,
shutdownTime.get());
", closed: %s, " + "isShutdown: %s, shutdown time: %s",
streamClosed, isShutdown, debugMetrics.shutdownTime());
}

/**
Expand All @@ -336,9 +302,9 @@ public final void appendSummaryHtml(PrintWriter writer) {
@Override
public final synchronized void halfClose() {
// Synchronization of close and onCompleted necessary for correct retry logic in onNewStream.
clientClosed.set(true);
clientClosed = true;
requestObserver.onCompleted();
streamClosed.set(true);
streamClosed = true;
}

@Override
Expand All @@ -348,7 +314,7 @@ public final boolean awaitTermination(int time, TimeUnit unit) throws Interrupte

@Override
public final Instant startTime() {
return new Instant(startTimeMs.get());
return new Instant(debugMetrics.startTimeMs());
}

@Override
Expand All @@ -363,71 +329,15 @@ public final void shutdown() {
synchronized (shutdownLock) {
if (!isShutdown) {
isShutdown = true;
shutdownTime.set(DateTime.now());
if (started) {
// requestObserver is not set until the first startStream() is called. If the stream was
// never started there is nothing to clean up internally.
requestObserver.onError(
new WindmillStreamShutdownException("Explicit call to shutdown stream."));
shutdownInternal();
}
debugMetrics.recordShutdown();
requestObserver.poison();
shutdownInternal();
}
}
}

private void recordRestartReason(String error) {
lastRestartReason.set(error);
lastRestartTime.set(DateTime.now());
}

protected abstract void shutdownInternal();

/**
* Request observer that allows resetting its internal delegate using the given {@link
* #requestObserverSupplier}.
*
* @implNote {@link StreamObserver}s generated by {@link * #requestObserverSupplier} are expected
* to be {@link ThreadSafe}.
*/
@ThreadSafe
private static class ResettableRequestObserver<RequestT> implements StreamObserver<RequestT> {

private final Supplier<StreamObserver<RequestT>> requestObserverSupplier;

@GuardedBy("this")
private @Nullable StreamObserver<RequestT> delegateRequestObserver;

private ResettableRequestObserver(Supplier<StreamObserver<RequestT>> requestObserverSupplier) {
this.requestObserverSupplier = requestObserverSupplier;
this.delegateRequestObserver = null;
}

private synchronized StreamObserver<RequestT> delegate() {
return Preconditions.checkNotNull(
delegateRequestObserver,
"requestObserver cannot be null. Missing a call to startStream() to initialize.");
}

private synchronized void reset() {
delegateRequestObserver = requestObserverSupplier.get();
}

@Override
public void onNext(RequestT requestT) {
delegate().onNext(requestT);
}

@Override
public void onError(Throwable throwable) {
delegate().onError(throwable);
}

@Override
public void onCompleted() {
delegate().onCompleted();
}
}

private class ResponseObserver implements StreamObserver<ResponseT> {

@Override
Expand All @@ -437,7 +347,7 @@ public void onNext(ResponseT response) {
} catch (IOException e) {
// Ignore.
}
lastResponseTimeMs.set(Instant.now().getMillis());
debugMetrics.recordResponse();
onResponse(response);
}

Expand All @@ -451,7 +361,7 @@ public void onError(Throwable t) {

try {
long sleep = backoff.nextBackOffMillis();
sleepUntil.set(Instant.now().getMillis() + sleep);
debugMetrics.recordSleep(sleep);
sleeper.sleep(sleep);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
Expand All @@ -473,16 +383,16 @@ public void onCompleted() {
}

private void recordStreamStatus(Status status) {
int currentRestartCount = restartCount.incrementAndGet();
int currentRestartCount = debugMetrics.incrementAndGetRestarts();
if (status.isOk()) {
String restartReason =
"Stream completed successfully but did not complete requested operations, "
+ "recreating";
logger.warn(restartReason);
recordRestartReason(restartReason);
debugMetrics.recordRestartReason(restartReason);
} else {
int currentErrorCount = errorCount.incrementAndGet();
recordRestartReason(status.toString());
int currentErrorCount = debugMetrics.incrementAndGetErrors();
debugMetrics.recordRestartReason(status.toString());
Throwable t = status.getCause();
if (t instanceof StreamObserverCancelledException) {
logger.error(
Expand All @@ -494,11 +404,6 @@ private void recordStreamStatus(Status status) {
} else if (currentRestartCount % logEveryNStreamFailures == 0) {
// Don't log every restart since it will get noisy, and many errors transient.
long nowMillis = Instant.now().getMillis();
String responseDebug =
lastResponseTimeMs.get() == 0
? "never received response"
: "received response " + (nowMillis - lastResponseTimeMs.get()) + "ms ago";

logger.debug(
"{} has been restarted {} times. Streaming Windmill RPC Error Count: {}; last was: {}"
+ " with status: {}. created {}ms ago; {}. This is normal with autoscaling.",
Expand All @@ -507,8 +412,8 @@ private void recordStreamStatus(Status status) {
currentErrorCount,
t,
status,
nowMillis - startTimeMs.get(),
responseDebug);
nowMillis - debugMetrics.startTimeMs(),
debugMetrics.responseDebugString(nowMillis));
}

// If the stream was stopped due to a resource exhausted error then we are throttled.
Expand All @@ -520,7 +425,7 @@ private void recordStreamStatus(Status status) {

/** Returns true if the stream was torn down and should not be restarted internally. */
private synchronized boolean maybeTeardownStream() {
if (isShutdown || (clientClosed.get() && !hasPendingRequests())) {
if (isShutdown || (clientClosed && !hasPendingRequests())) {
streamRegistry.remove(AbstractWindmillStream.this);
finishLatch.countDown();
executor.shutdownNow();
Expand Down
Loading

0 comments on commit 77969d0

Please sign in to comment.