Skip to content

Commit

Permalink
aws#4801 Cancelling origin body to prevent stuck calling thread
Browse files Browse the repository at this point in the history
  • Loading branch information
Kirill Chaykin committed Jan 5, 2024
1 parent a5193ea commit 8c0204e
Showing 1 changed file with 82 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,15 @@
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;

import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription;
import software.amazon.awssdk.annotations.SdkPublicApi;
import software.amazon.awssdk.core.exception.NonRetryableException;
import software.amazon.awssdk.core.internal.async.SplittingPublisher;
import software.amazon.awssdk.core.internal.io.SdkLengthAwareInputStream;
import software.amazon.awssdk.core.internal.util.NoopSubscription;
import software.amazon.awssdk.utils.async.DelegatingSubscriber;
import software.amazon.awssdk.utils.async.InputStreamConsumingPublisher;

/**
Expand All @@ -44,17 +48,17 @@ public final class BlockingInputStreamAsyncRequestBody implements AsyncRequestBo
private final Duration subscribeTimeout;

BlockingInputStreamAsyncRequestBody(Long contentLength) {
this(contentLength, Duration.ofSeconds(10));
this(contentLength, Duration.ofSeconds(10));
}

BlockingInputStreamAsyncRequestBody(Long contentLength, Duration subscribeTimeout) {
this.contentLength = contentLength;
this.subscribeTimeout = subscribeTimeout;
this.contentLength = contentLength;
this.subscribeTimeout = subscribeTimeout;
}

@Override
public Optional<Long> contentLength() {
return Optional.ofNullable(contentLength);
return Optional.ofNullable(contentLength);
}

/**
Expand All @@ -70,46 +74,92 @@ public Optional<Long> contentLength() {
* failed).
*/
public long writeInputStream(InputStream inputStream) {
try {
waitForSubscriptionIfNeeded();
if (contentLength != null) {
return delegate.doBlockingWrite(new SdkLengthAwareInputStream(inputStream, contentLength));
}

return delegate.doBlockingWrite(inputStream);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
delegate.cancel();
throw new RuntimeException(e);
}
try {
waitForSubscriptionIfNeeded();
if (contentLength != null) {
return delegate.doBlockingWrite(new SdkLengthAwareInputStream(inputStream, contentLength));
}

return delegate.doBlockingWrite(inputStream);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
delegate.cancel();
throw new RuntimeException(e);
}
}

/**
* Cancel any running write (and mark the stream as failed).
*/
public void cancel() {
delegate.cancel();
delegate.cancel();
}

@Override
public void subscribe(Subscriber<? super ByteBuffer> s) {
if (subscribeCalled.compareAndSet(false, true)) {
delegate.subscribe(s);
subscribedLatch.countDown();
} else {
s.onSubscribe(new NoopSubscription(s));
s.onError(NonRetryableException.create("A retry was attempted, but AsyncRequestBody.forBlockingInputStream does not "
+ "support retries. Consider using AsyncRequestBody.fromInputStream with an "
+ "input stream that supports mark/reset to get retry support."));
}
if (subscribeCalled.compareAndSet(false, true)) {
delegate.subscribe(s);
subscribedLatch.countDown();
} else {
s.onSubscribe(new NoopSubscription(s));
s.onError(NonRetryableException.create(
"A retry was attempted, but AsyncRequestBody.forBlockingInputStream does not "
+ "support retries. Consider using AsyncRequestBody.fromInputStream with an "
+ "input stream that supports mark/reset to get retry support."));
}
}

@Override
public SdkPublisher<AsyncRequestBody> split(AsyncRequestBodySplitConfiguration splitConfiguration) {
return new BlockingSplittingPublisher(this, splitConfiguration);
}

private void waitForSubscriptionIfNeeded() throws InterruptedException {
long timeoutSeconds = subscribeTimeout.getSeconds();
if (!subscribedLatch.await(timeoutSeconds, TimeUnit.SECONDS)) {
throw new IllegalStateException("The service request was not made within " + timeoutSeconds + " seconds of "
+ "doBlockingWrite being invoked. Make sure to invoke the service request "
+ "BEFORE invoking doBlockingWrite if your caller is single-threaded.");
}
long timeoutSeconds = subscribeTimeout.getSeconds();
if (!subscribedLatch.await(timeoutSeconds, TimeUnit.SECONDS)) {
throw new IllegalStateException("The service request was not made within " + timeoutSeconds + " seconds of "
+ "doBlockingWrite being invoked. Make sure to invoke the service request "
+ "BEFORE invoking doBlockingWrite if your caller is single-threaded.");
}
}

private class BlockingSplittingPublisher extends SplittingPublisher {

public BlockingSplittingPublisher(AsyncRequestBody asyncRequestBody,
AsyncRequestBodySplitConfiguration splitConfiguration) {
super(asyncRequestBody, splitConfiguration);
}

@Override
public void subscribe(Subscriber<? super AsyncRequestBody> downstreamSubscriber) {
Subscriber<? super AsyncRequestBody> delegatingSubscriber = new DelegatingSubscriber<AsyncRequestBody, AsyncRequestBody>(
downstreamSubscriber) {
@Override
public void onSubscribe(Subscription subscription) {
Subscription delegatingSubscription = new Subscription() {
@Override
public void request(long n) {
subscription.request(n);
}

@Override
public void cancel() {
subscription.cancel();

//Cancel origin body to prevent stuck calling thread
BlockingInputStreamAsyncRequestBody.this.cancel();
}
};
super.onSubscribe(delegatingSubscription);
}

@Override
public void onNext(AsyncRequestBody body) {
subscriber.onNext(body);
}
};

super.subscribe(delegatingSubscriber);
}
}
}

0 comments on commit 8c0204e

Please sign in to comment.