diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/async/BlockingInputStreamAsyncRequestBody.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/async/BlockingInputStreamAsyncRequestBody.java index 9b19907ce36b..ecbaab923101 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/async/BlockingInputStreamAsyncRequestBody.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/async/BlockingInputStreamAsyncRequestBody.java @@ -22,11 +22,15 @@ import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; + import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; import software.amazon.awssdk.annotations.SdkPublicApi; import software.amazon.awssdk.core.exception.NonRetryableException; +import software.amazon.awssdk.core.internal.async.SplittingPublisher; import software.amazon.awssdk.core.internal.io.SdkLengthAwareInputStream; import software.amazon.awssdk.core.internal.util.NoopSubscription; +import software.amazon.awssdk.utils.async.DelegatingSubscriber; import software.amazon.awssdk.utils.async.InputStreamConsumingPublisher; /** @@ -44,17 +48,17 @@ public final class BlockingInputStreamAsyncRequestBody implements AsyncRequestBo private final Duration subscribeTimeout; BlockingInputStreamAsyncRequestBody(Long contentLength) { - this(contentLength, Duration.ofSeconds(10)); + this(contentLength, Duration.ofSeconds(10)); } BlockingInputStreamAsyncRequestBody(Long contentLength, Duration subscribeTimeout) { - this.contentLength = contentLength; - this.subscribeTimeout = subscribeTimeout; + this.contentLength = contentLength; + this.subscribeTimeout = subscribeTimeout; } @Override public Optional contentLength() { - return Optional.ofNullable(contentLength); + return Optional.ofNullable(contentLength); } /** @@ -70,46 +74,92 @@ public Optional contentLength() { * failed). */ public long writeInputStream(InputStream inputStream) { - try { - waitForSubscriptionIfNeeded(); - if (contentLength != null) { - return delegate.doBlockingWrite(new SdkLengthAwareInputStream(inputStream, contentLength)); - } - - return delegate.doBlockingWrite(inputStream); - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - delegate.cancel(); - throw new RuntimeException(e); - } + try { + waitForSubscriptionIfNeeded(); + if (contentLength != null) { + return delegate.doBlockingWrite(new SdkLengthAwareInputStream(inputStream, contentLength)); + } + + return delegate.doBlockingWrite(inputStream); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + delegate.cancel(); + throw new RuntimeException(e); + } } /** * Cancel any running write (and mark the stream as failed). */ public void cancel() { - delegate.cancel(); + delegate.cancel(); } @Override public void subscribe(Subscriber s) { - if (subscribeCalled.compareAndSet(false, true)) { - delegate.subscribe(s); - subscribedLatch.countDown(); - } else { - s.onSubscribe(new NoopSubscription(s)); - s.onError(NonRetryableException.create("A retry was attempted, but AsyncRequestBody.forBlockingInputStream does not " - + "support retries. Consider using AsyncRequestBody.fromInputStream with an " - + "input stream that supports mark/reset to get retry support.")); - } + if (subscribeCalled.compareAndSet(false, true)) { + delegate.subscribe(s); + subscribedLatch.countDown(); + } else { + s.onSubscribe(new NoopSubscription(s)); + s.onError(NonRetryableException.create( + "A retry was attempted, but AsyncRequestBody.forBlockingInputStream does not " + + "support retries. Consider using AsyncRequestBody.fromInputStream with an " + + "input stream that supports mark/reset to get retry support.")); + } + } + + @Override + public SdkPublisher split(AsyncRequestBodySplitConfiguration splitConfiguration) { + return new BlockingSplittingPublisher(this, splitConfiguration); } private void waitForSubscriptionIfNeeded() throws InterruptedException { - long timeoutSeconds = subscribeTimeout.getSeconds(); - if (!subscribedLatch.await(timeoutSeconds, TimeUnit.SECONDS)) { - throw new IllegalStateException("The service request was not made within " + timeoutSeconds + " seconds of " - + "doBlockingWrite being invoked. Make sure to invoke the service request " - + "BEFORE invoking doBlockingWrite if your caller is single-threaded."); - } + long timeoutSeconds = subscribeTimeout.getSeconds(); + if (!subscribedLatch.await(timeoutSeconds, TimeUnit.SECONDS)) { + throw new IllegalStateException("The service request was not made within " + timeoutSeconds + " seconds of " + + "doBlockingWrite being invoked. Make sure to invoke the service request " + + "BEFORE invoking doBlockingWrite if your caller is single-threaded."); + } + } + + private class BlockingSplittingPublisher extends SplittingPublisher { + + public BlockingSplittingPublisher(AsyncRequestBody asyncRequestBody, + AsyncRequestBodySplitConfiguration splitConfiguration) { + super(asyncRequestBody, splitConfiguration); + } + + @Override + public void subscribe(Subscriber downstreamSubscriber) { + Subscriber delegatingSubscriber = new DelegatingSubscriber( + downstreamSubscriber) { + @Override + public void onSubscribe(Subscription subscription) { + Subscription delegatingSubscription = new Subscription() { + @Override + public void request(long n) { + subscription.request(n); + } + + @Override + public void cancel() { + subscription.cancel(); + + //Cancel origin body to prevent stuck calling thread + BlockingInputStreamAsyncRequestBody.this.cancel(); + } + }; + super.onSubscribe(delegatingSubscription); + } + + @Override + public void onNext(AsyncRequestBody body) { + subscriber.onNext(body); + } + }; + + super.subscribe(delegatingSubscriber); + } } }