diff --git a/build.sbt b/build.sbt index cf1f3897..0ef75941 100644 --- a/build.sbt +++ b/build.sbt @@ -32,7 +32,8 @@ lazy val mimaSettings = mimaDefaultSettings ++ Seq( mimaBinaryIssueFilters ++= Seq( ProblemFilters.exclude[DirectMissingMethodProblem]("play.libs.ws.ahc.StandaloneAhcWSResponse.getBodyAsSource"), ProblemFilters.exclude[MissingClassProblem]("play.api.libs.ws.package$"), - ProblemFilters.exclude[MissingClassProblem]("play.api.libs.ws.package") + ProblemFilters.exclude[MissingClassProblem]("play.api.libs.ws.package"), + ProblemFilters.exclude[DirectMissingMethodProblem]("play.api.libs.ws.ahc.DefaultStreamedAsyncHandler.this") ) ) diff --git a/play-ahc-ws-standalone/src/main/java/play/libs/ws/ahc/StandaloneAhcWSClient.java b/play-ahc-ws-standalone/src/main/java/play/libs/ws/ahc/StandaloneAhcWSClient.java index cad0c08c..5634303f 100644 --- a/play-ahc-ws-standalone/src/main/java/play/libs/ws/ahc/StandaloneAhcWSClient.java +++ b/play-ahc-ws-standalone/src/main/java/play/libs/ws/ahc/StandaloneAhcWSClient.java @@ -4,31 +4,36 @@ package play.libs.ws.ahc; +import akka.Done; import akka.stream.Materializer; import akka.stream.javadsl.Source; import akka.util.ByteString; import akka.util.ByteStringBuilder; import com.typesafe.sslconfig.ssl.SystemConfiguration; import com.typesafe.sslconfig.ssl.debug.DebugConfiguration; +import org.reactivestreams.Publisher; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; import org.slf4j.LoggerFactory; -import play.api.libs.ws.ahc.AhcConfigBuilder; -import play.api.libs.ws.ahc.AhcLoggerFactory; -import play.api.libs.ws.ahc.AhcWSClientConfig; -import play.api.libs.ws.ahc.DefaultStreamedAsyncHandler; +import play.api.libs.ws.ahc.*; import play.api.libs.ws.ahc.cache.AhcHttpCache; import play.api.libs.ws.ahc.cache.CachingAsyncHttpClient; import play.libs.ws.StandaloneWSClient; import play.libs.ws.StandaloneWSResponse; import play.shaded.ahc.org.asynchttpclient.*; +import scala.Function1; import scala.compat.java8.FutureConverters; +import scala.compat.java8.FutureConverters$; import scala.concurrent.ExecutionContext; import scala.concurrent.Future; import scala.concurrent.Promise; +import scala.util.Try; import javax.inject.Inject; import java.io.IOException; import java.util.concurrent.CompletionStage; import java.util.concurrent.ExecutionException; +import java.util.function.Function; /** * A WS asyncHttpClient backed by an AsyncHttpClient instance. @@ -92,33 +97,73 @@ public void onThrowable(Throwable t) { } CompletionStage executeStream(Request request, ExecutionContext ec) { - final Promise scalaPromise = scala.concurrent.Promise$.MODULE$.apply(); + final Promise streamStarted = scala.concurrent.Promise$.MODULE$.apply(); + final Promise streamCompletion = scala.concurrent.Promise$.MODULE$.apply(); + + Function f = state -> { + Publisher publisher = state.publisher(); + Publisher wrap = new Publisher() { + @Override + public void subscribe(Subscriber s) { + publisher.subscribe( + new Subscriber() { + @Override + public void onSubscribe(Subscription sub) { + s.onSubscribe(sub); + } + + @Override + public void onNext(HttpResponseBodyPart httpResponseBodyPart) { + s.onNext(httpResponseBodyPart); + } + + @Override + public void onError(Throwable t) { + s.onError(t); + } + + @Override + public void onComplete() { + FutureConverters$.MODULE$.toJava(streamCompletion.future()) + .handle((d, t) -> { + if (d != null) s.onComplete(); + else s.onError(t); + return null; + }); + } + } + ); + } + }; + + return new StreamedResponse(this, + state.statusCode(), + state.statusText(), + state.uriOption().get(), + state.responseHeaders(), + wrap); + }; - asyncHttpClient.executeRequest(request, new DefaultStreamedAsyncHandler<>(state -> - new StreamedResponse(this, - state.statusCode(), - state.statusText(), - state.uriOption().get(), - state.responseHeaders(), - state.publisher()), - scalaPromise)); - return FutureConverters.toJava(scalaPromise.future()); + asyncHttpClient.executeRequest(request, new DefaultStreamedAsyncHandler<>(f, + streamStarted, + streamCompletion + )); + return FutureConverters.toJava(streamStarted.future()); } /** * A convenience method for creating a StandaloneAhcWSClient from configuration. * * @param ahcWSClientConfig the configuration object - * @param materializer an akka materializer + * @param materializer an akka materializer * @return a fully configured StandaloneAhcWSClient instance. - * * @see #create(AhcWSClientConfig, AhcHttpCache, Materializer) */ public static StandaloneAhcWSClient create(AhcWSClientConfig ahcWSClientConfig, Materializer materializer) { return create( - ahcWSClientConfig, - null /* no cache*/, - materializer + ahcWSClientConfig, + null /* no cache*/, + materializer ); } @@ -161,10 +206,10 @@ public static StandaloneAhcWSClient create(AhcWSClientConfig ahcWSClientConfig, ByteString blockingToByteString(Source bodyAsSource) { try { return bodyAsSource - .runFold(ByteString.createBuilder(), ByteStringBuilder::append, materializer) - .thenApply(ByteStringBuilder::result) - .toCompletableFuture() - .get(); + .runFold(ByteString.createBuilder(), ByteStringBuilder::append, materializer) + .thenApply(ByteStringBuilder::result) + .toCompletableFuture() + .get(); } catch (InterruptedException | ExecutionException e) { throw new RuntimeException(e); } diff --git a/play-ahc-ws-standalone/src/main/scala/play/api/libs/ws/ahc/StandaloneAhcWSClient.scala b/play-ahc-ws-standalone/src/main/scala/play/api/libs/ws/ahc/StandaloneAhcWSClient.scala index 741861b3..f5c77d9a 100644 --- a/play-ahc-ws-standalone/src/main/scala/play/api/libs/ws/ahc/StandaloneAhcWSClient.scala +++ b/play-ahc-ws-standalone/src/main/scala/play/api/libs/ws/ahc/StandaloneAhcWSClient.scala @@ -3,21 +3,32 @@ */ package play.api.libs.ws.ahc +import akka.Done import javax.inject.Inject - import akka.stream.Materializer import akka.stream.scaladsl.Source import akka.util.ByteString import com.typesafe.sslconfig.ssl.SystemConfiguration import com.typesafe.sslconfig.ssl.debug.DebugConfiguration +import org.reactivestreams.Publisher +import org.reactivestreams.Subscriber +import org.reactivestreams.Subscription import play.api.libs.ws.ahc.cache._ -import play.api.libs.ws.{ EmptyBody, StandaloneWSClient, StandaloneWSRequest } +import play.api.libs.ws.EmptyBody +import play.api.libs.ws.StandaloneWSClient +import play.api.libs.ws.StandaloneWSRequest import play.shaded.ahc.org.asynchttpclient.uri.Uri -import play.shaded.ahc.org.asynchttpclient.{ Response => AHCResponse, _ } +import play.shaded.ahc.org.asynchttpclient.{ Response => AHCResponse } +import play.shaded.ahc.org.asynchttpclient._ +import java.util.function.{ Function => JFunction } import scala.collection.immutable.TreeMap -import scala.compat.java8.FunctionConverters -import scala.concurrent.{ Await, Future, Promise } +import scala.compat.java8.FunctionConverters._ +import scala.concurrent.Await +import scala.concurrent.Future +import scala.concurrent.Promise +import scala.util.Failure +import scala.util.Success /** * A WS client backed by an AsyncHttpClient. @@ -29,7 +40,10 @@ import scala.concurrent.{ Await, Future, Promise } * also close asyncHttpClient. * @param materializer A materializer, meant to execute the stream */ -class StandaloneAhcWSClient @Inject() (asyncHttpClient: AsyncHttpClient)(implicit materializer: Materializer) extends StandaloneWSClient { +class StandaloneAhcWSClient @Inject() (asyncHttpClient: AsyncHttpClient)( + implicit + materializer: Materializer +) extends StandaloneWSClient { /** Returns instance of AsyncHttpClient */ def underlying[T]: T = asyncHttpClient.asInstanceOf[T] @@ -58,7 +72,9 @@ class StandaloneAhcWSClient @Inject() (asyncHttpClient: AsyncHttpClient)(implici ) } - private[ahc] def execute(request: Request): Future[StandaloneAhcWSResponse] = { + private[ahc] def execute( + request: Request + ): Future[StandaloneAhcWSResponse] = { val result = Promise[StandaloneAhcWSResponse]() val handler = new AsyncCompletionHandler[AHCResponse]() { override def onCompleted(response: AHCResponse): AHCResponse = { @@ -88,30 +104,71 @@ class StandaloneAhcWSClient @Inject() (asyncHttpClient: AsyncHttpClient)(implici } private[ahc] def executeStream(request: Request): Future[StreamedResponse] = { - val promise = Promise[StreamedResponse]() - - val function = FunctionConverters.asJavaFunction[StreamedState, StreamedResponse](state => - new StreamedResponse( - this, - state.statusCode, - state.statusText, - state.uriOption.get, - state.responseHeaders, - state.publisher) + val streamStarted = Promise[StreamedResponse]() + val streamCompletion = Promise[Done]() + + val client = this + + val function: JFunction[StreamedState, StreamedResponse] = { + state: StreamedState => + val publisher = state.publisher + + val wrap = new Publisher[HttpResponseBodyPart]() { + override def subscribe( + s: Subscriber[_ >: HttpResponseBodyPart] + ): Unit = { + publisher.subscribe(new Subscriber[HttpResponseBodyPart] { + override def onSubscribe(sub: Subscription): Unit = + s.onSubscribe(sub) + + override def onNext(t: HttpResponseBodyPart): Unit = s.onNext(t) + + override def onError(t: Throwable): Unit = s.onError(t) + + override def onComplete(): Unit = { + streamCompletion.future.onComplete { + case Success(_) => s.onComplete() + case Failure(t) => s.onError(t) + }(materializer.executionContext) + } + }) + } + + } + new StreamedResponse( + client, + state.statusCode, + state.statusText, + state.uriOption.get, + state.responseHeaders, + wrap + ) + + }.asJava + asyncHttpClient.executeRequest( + request, + new DefaultStreamedAsyncHandler[StreamedResponse]( + function, + streamStarted, + streamCompletion + ) ) - asyncHttpClient.executeRequest(request, new DefaultStreamedAsyncHandler[StreamedResponse](function, promise)) - promise.future + streamStarted.future } private[ahc] def blockingToByteString(bodyAsSource: Source[ByteString, _]) = { - StandaloneAhcWSClient.logger.warn(s"blockingToByteString is a blocking and unsafe operation!") + StandaloneAhcWSClient.logger.warn( + s"blockingToByteString is a blocking and unsafe operation!" + ) import scala.concurrent.ExecutionContext.Implicits.global val limitedSource = bodyAsSource.limit(StandaloneAhcWSClient.elementLimit) - val result = limitedSource.runFold(ByteString.createBuilder) { (acc, bs) => - acc.append(bs) - }.map(_.result()) + val result = limitedSource + .runFold(ByteString.createBuilder) { (acc, bs) => + acc.append(bs) + } + .map(_.result()) Await.result(result, StandaloneAhcWSClient.blockingTimeout) } @@ -125,7 +182,9 @@ object StandaloneAhcWSClient { val elementLimit = 13 // 13 8192k blocks is roughly 100k private val logger = org.slf4j.LoggerFactory.getLogger(this.getClass) - private[ahc] val loggerFactory = new AhcLoggerFactory(org.slf4j.LoggerFactory.getILoggerFactory) + private[ahc] val loggerFactory = new AhcLoggerFactory( + org.slf4j.LoggerFactory.getILoggerFactory + ) /** * Convenient factory method that uses a play.api.libs.ws.WSClientConfig value for configuration instead of @@ -146,21 +205,26 @@ object StandaloneAhcWSClient { * @param httpCache if not null, will be used for HTTP response caching. * @param materializer the akka materializer. */ - def apply(config: AhcWSClientConfig = AhcWSClientConfigFactory.forConfig(), httpCache: Option[AhcHttpCache] = None)(implicit materializer: Materializer): StandaloneAhcWSClient = { + def apply( + config: AhcWSClientConfig = AhcWSClientConfigFactory.forConfig(), + httpCache: Option[AhcHttpCache] = None + )(implicit materializer: Materializer): StandaloneAhcWSClient = { if (config.wsClientConfig.ssl.debug.enabled) { - new DebugConfiguration(StandaloneAhcWSClient.loggerFactory).configure(config.wsClientConfig.ssl.debug) + new DebugConfiguration(StandaloneAhcWSClient.loggerFactory) + .configure(config.wsClientConfig.ssl.debug) } val ahcConfig = new AhcConfigBuilder(config).build() val asyncHttpClient = new DefaultAsyncHttpClient(ahcConfig) val wsClient = new StandaloneAhcWSClient( - httpCache.map { cache => - new CachingAsyncHttpClient(asyncHttpClient, cache) - }.getOrElse { - asyncHttpClient - } + httpCache + .map { cache => + new CachingAsyncHttpClient(asyncHttpClient, cache) + } + .getOrElse { + asyncHttpClient + } ) new SystemConfiguration(loggerFactory).configure(config.wsClientConfig.ssl) wsClient } } - diff --git a/play-ahc-ws-standalone/src/main/scala/play/api/libs/ws/ahc/Streamed.scala b/play-ahc-ws-standalone/src/main/scala/play/api/libs/ws/ahc/Streamed.scala index f09e2559..c4025cc9 100644 --- a/play-ahc-ws-standalone/src/main/scala/play/api/libs/ws/ahc/Streamed.scala +++ b/play-ahc-ws-standalone/src/main/scala/play/api/libs/ws/ahc/Streamed.scala @@ -5,7 +5,8 @@ package play.api.libs.ws.ahc import java.net.URI -import org.reactivestreams.{ Publisher, Subscriber, Subscription } +import akka.Done +import org.reactivestreams.{ Subscriber, Subscription, Publisher } import play.shaded.ahc.org.asynchttpclient.AsyncHandler.State import play.shaded.ahc.org.asynchttpclient._ import play.shaded.ahc.org.asynchttpclient.handler.StreamedAsyncHandler @@ -20,14 +21,14 @@ case class StreamedState( publisher: Publisher[HttpResponseBodyPart] = EmptyPublisher ) -class DefaultStreamedAsyncHandler[T](f: java.util.function.Function[StreamedState, T], promise: Promise[T]) extends StreamedAsyncHandler[Unit] with AhcUtilities { +class DefaultStreamedAsyncHandler[T](f: java.util.function.Function[StreamedState, T], streamStarted: Promise[T], streamDone: Promise[Done]) extends StreamedAsyncHandler[Unit] with AhcUtilities { private var state = StreamedState() def onStream(publisher: Publisher[HttpResponseBodyPart]): State = { if (this.state.publisher != EmptyPublisher) State.ABORT else { this.state = state.copy(publisher = publisher) - promise.success(f(state)) + streamStarted.success(f(state)) State.CONTINUE } } @@ -58,15 +59,20 @@ class DefaultStreamedAsyncHandler[T](f: java.util.function.Function[StreamedStat override def onCompleted(): Unit = { // EmptyPublisher can be replaces with `Source.empty` when we carry out the refactoring // mentioned in the `execute2` method. - promise.trySuccess(f(state.copy(publisher = EmptyPublisher))) + streamStarted.trySuccess(f(state.copy(publisher = EmptyPublisher))) + streamDone.trySuccess(Done) } - override def onThrowable(t: Throwable): Unit = promise.tryFailure(t) + override def onThrowable(t: Throwable): Unit = { + streamStarted.tryFailure(t) + streamDone.tryFailure(t) + } } private case object EmptyPublisher extends Publisher[HttpResponseBodyPart] { def subscribe(s: Subscriber[_ >: HttpResponseBodyPart]): Unit = { - if (s eq null) throw new NullPointerException("Subscriber must not be null, rule 1.9") + if (s eq null) + throw new NullPointerException("Subscriber must not be null, rule 1.9") s.onSubscribe(CancelledSubscription) s.onComplete() }