Skip to content

Commit

Permalink
address PR comments
Browse files Browse the repository at this point in the history
  • Loading branch information
m-trieu committed Oct 24, 2024
1 parent 6f052de commit fb8573a
Show file tree
Hide file tree
Showing 3 changed files with 102 additions and 71 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -225,86 +225,89 @@ private void flushInternal(Map<Long, PendingRequest> requests) {
}

private void issueSingleRequest(long id, PendingRequest pendingRequest) {
if (prepareForSend(id, pendingRequest)) {
StreamingCommitWorkRequest.Builder requestBuilder = StreamingCommitWorkRequest.newBuilder();
requestBuilder
.addCommitChunkBuilder()
.setComputationId(pendingRequest.computationId())
.setRequestId(id)
.setShardingKey(pendingRequest.shardingKey())
.setSerializedWorkItemCommit(pendingRequest.serializedCommit());
StreamingCommitWorkRequest chunk = requestBuilder.build();
try {
send(chunk);
} catch (IllegalStateException e) {
// Stream was broken, request will be retried when stream is reopened.
}
} else {
if (!prepareForSend(id, pendingRequest)) {
pendingRequest.abort();
return;
}

StreamingCommitWorkRequest.Builder requestBuilder = StreamingCommitWorkRequest.newBuilder();
requestBuilder
.addCommitChunkBuilder()
.setComputationId(pendingRequest.computationId())
.setRequestId(id)
.setShardingKey(pendingRequest.shardingKey())
.setSerializedWorkItemCommit(pendingRequest.serializedCommit());
StreamingCommitWorkRequest chunk = requestBuilder.build();
try {
send(chunk);
} catch (IllegalStateException e) {
// Stream was broken, request will be retried when stream is reopened.
}
}

private void issueBatchedRequest(Map<Long, PendingRequest> requests) {
if (prepareForSend(requests)) {
StreamingCommitWorkRequest.Builder requestBuilder = StreamingCommitWorkRequest.newBuilder();
String lastComputation = null;
for (Map.Entry<Long, PendingRequest> entry : requests.entrySet()) {
PendingRequest request = entry.getValue();
StreamingCommitRequestChunk.Builder chunkBuilder = requestBuilder.addCommitChunkBuilder();
if (lastComputation == null || !lastComputation.equals(request.computationId())) {
chunkBuilder.setComputationId(request.computationId());
lastComputation = request.computationId();
}
chunkBuilder
.setRequestId(entry.getKey())
.setShardingKey(request.shardingKey())
.setSerializedWorkItemCommit(request.serializedCommit());
}
StreamingCommitWorkRequest request = requestBuilder.build();
try {
send(request);
} catch (IllegalStateException e) {
// Stream was broken, request will be retried when stream is reopened.
}
} else {
if (!prepareForSend(requests)) {
requests.forEach((ignored, pendingRequest) -> pendingRequest.abort());
return;
}

StreamingCommitWorkRequest.Builder requestBuilder = StreamingCommitWorkRequest.newBuilder();
String lastComputation = null;
for (Map.Entry<Long, PendingRequest> entry : requests.entrySet()) {
PendingRequest request = entry.getValue();
StreamingCommitRequestChunk.Builder chunkBuilder = requestBuilder.addCommitChunkBuilder();
if (lastComputation == null || !lastComputation.equals(request.computationId())) {
chunkBuilder.setComputationId(request.computationId());
lastComputation = request.computationId();
}
chunkBuilder
.setRequestId(entry.getKey())
.setShardingKey(request.shardingKey())
.setSerializedWorkItemCommit(request.serializedCommit());
}
StreamingCommitWorkRequest request = requestBuilder.build();
try {
send(request);
} catch (IllegalStateException e) {
// Stream was broken, request will be retried when stream is reopened.
}
}

private void issueMultiChunkRequest(long id, PendingRequest pendingRequest) {
if (prepareForSend(id, pendingRequest)) {
checkNotNull(pendingRequest.computationId(), "Cannot commit WorkItem w/o a computationId.");
ByteString serializedCommit = pendingRequest.serializedCommit();
synchronized (this) {
for (int i = 0;
i < serializedCommit.size();
i += AbstractWindmillStream.RPC_STREAM_CHUNK_SIZE) {
int end = i + AbstractWindmillStream.RPC_STREAM_CHUNK_SIZE;
ByteString chunk = serializedCommit.substring(i, Math.min(end, serializedCommit.size()));

StreamingCommitRequestChunk.Builder chunkBuilder =
StreamingCommitRequestChunk.newBuilder()
.setRequestId(id)
.setSerializedWorkItemCommit(chunk)
.setComputationId(pendingRequest.computationId())
.setShardingKey(pendingRequest.shardingKey());
int remaining = serializedCommit.size() - end;
if (remaining > 0) {
chunkBuilder.setRemainingBytesForWorkItem(remaining);
}

StreamingCommitWorkRequest requestChunk =
StreamingCommitWorkRequest.newBuilder().addCommitChunk(chunkBuilder).build();
try {
send(requestChunk);
} catch (IllegalStateException e) {
// Stream was broken, request will be retried when stream is reopened.
break;
}
if (!prepareForSend(id, pendingRequest)) {
pendingRequest.abort();
return;
}

checkNotNull(pendingRequest.computationId(), "Cannot commit WorkItem w/o a computationId.");
ByteString serializedCommit = pendingRequest.serializedCommit();
synchronized (this) {
for (int i = 0;
i < serializedCommit.size();
i += AbstractWindmillStream.RPC_STREAM_CHUNK_SIZE) {
int end = i + AbstractWindmillStream.RPC_STREAM_CHUNK_SIZE;
ByteString chunk = serializedCommit.substring(i, Math.min(end, serializedCommit.size()));

StreamingCommitRequestChunk.Builder chunkBuilder =
StreamingCommitRequestChunk.newBuilder()
.setRequestId(id)
.setSerializedWorkItemCommit(chunk)
.setComputationId(pendingRequest.computationId())
.setShardingKey(pendingRequest.shardingKey());
int remaining = serializedCommit.size() - end;
if (remaining > 0) {
chunkBuilder.setRemainingBytesForWorkItem(remaining);
}

StreamingCommitWorkRequest requestChunk =
StreamingCommitWorkRequest.newBuilder().addCommitChunk(chunkBuilder).build();
try {
send(requestChunk);
} catch (IllegalStateException e) {
// Stream was broken, request will be retried when stream is reopened.
break;
}
}
} else {
pendingRequest.abort();
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ final class GrpcGetDataStream
private static final StreamingGetDataRequest HEALTH_CHECK_REQUEST =
StreamingGetDataRequest.newBuilder().build();

/** @implNote insertion and removal is guarded by {@link #shutdownLock} */
/** @implNote {@link QueuedBatch} objects in the queue are is guarded by {@link #shutdownLock} */
private final Deque<QueuedBatch> batches;

private final Map<Long, AppendableInputStream> pending;
Expand Down Expand Up @@ -349,21 +349,36 @@ private <ResponseT> ResponseT issueRequest(QueuedRequest request, ParseFn<Respon
"Cannot send request=[" + request + "] on closed stream.");
}

private void handleShutdown(QueuedRequest request, Throwable cause) {
private void handleShutdown(QueuedRequest request, Throwable... causes) {
if (isShutdown()) {
WindmillStreamShutdownException shutdownException =
new WindmillStreamShutdownException(
"Cannot send request=[" + request + "] on closed stream.");
shutdownException.addSuppressed(cause);

for (Throwable cause : causes) {
shutdownException.addSuppressed(cause);
}

throw shutdownException;
}
}

private void handleShutdown(QueuedBatch batch) {
if (isShutdown()) {
throw new WindmillStreamShutdownException(
"Stream was closed when attempting to send " + batch.requestsCount() + " requests.");
}
}

private void queueRequestAndWait(QueuedRequest request) throws InterruptedException {
QueuedBatch batch;
boolean responsibleForSend = false;
@Nullable QueuedBatch prevBatch = null;
synchronized (shutdownLock) {
if (isShutdown()) {
handleShutdown(request);
}

batch = batches.isEmpty() ? null : batches.getLast();
if (batch == null
|| batch.isFinalized()
Expand All @@ -389,6 +404,10 @@ private void queueRequestAndWait(QueuedRequest request) throws InterruptedExcept
// Finalize the batch so that no additional requests will be added. Leave the batch in the
// queue so that a subsequent batch will wait for its completion.
synchronized (shutdownLock) {
if (isShutdown()) {
handleShutdown(batch);
}

verify(batch == batches.peekFirst(), "GetDataStream request batch removed before send().");
batch.markFinalized();
}
Expand All @@ -403,6 +422,10 @@ void trySendBatch(QueuedBatch batch) {
try {
sendBatch(batch);
synchronized (shutdownLock) {
if (isShutdown()) {
handleShutdown(batch);
}

verify(
batch == batches.pollFirst(),
"Sent GetDataStream request batch removed before send() was complete.");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,11 @@ void addToStreamingGetDataRequest(Windmill.StreamingGetDataRequest.Builder build
builder.addGlobalDataRequest(dataRequest.global());
}
}

@Override
public final String toString() {
return "QueuedRequest{" + "dataRequest=" + dataRequest + ", id=" + id + '}';
}
}

/**
Expand Down

0 comments on commit fb8573a

Please sign in to comment.