From 70d44ac0c08740d0d1b463cff58eda7ad0c92036 Mon Sep 17 00:00:00 2001
From: Roman Janusz
Date: Mon, 18 Mar 2024 15:51:07 +0100
Subject: [PATCH 01/16] SimpleSubscriber simplified even more

---
 build.sbt                                  |  3 +-
 .../netty/internal/NettyRequestBody.scala  | 46 ++++++------
 .../reactivestreams/SimpleSubscriber.scala | 70 +++++++------
 3 files changed, 51 insertions(+), 68 deletions(-)

diff --git a/build.sbt b/build.sbt
index 68155a7ace..e9f8347bf6 100644
--- a/build.sbt
+++ b/build.sbt
@@ -121,7 +121,8 @@ val commonJvmSettings: Seq[Def.Setting[_]] = commonSettings ++ Seq(
       case Some((2, _)) => Seq("-target:jvm-1.8") // some users are on java 8
       case _            => Seq.empty[String]
     }
-  }
+  },
+  run / fork := true
 )
 
 // run JS tests inside Gecko, due to jsdom not supporting fetch and to avoid having to install node

diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyRequestBody.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyRequestBody.scala
index fc11b33db4..87d3a8bd8d 100644
--- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyRequestBody.scala
+++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyRequestBody.scala
@@ -52,26 +52,29 @@ private[netty] trait NettyRequestBody[F[_], S <: Streams[S]] extends RequestBody
    */
   def writeToFile(serverRequest: ServerRequest, file: TapirFile, maxBytes: Option[Long]): F[Unit]
 
-  override def toRaw[RAW](serverRequest: ServerRequest, bodyType: RawBodyType[RAW], maxBytes: Option[Long]): F[RawValue[RAW]] = {
-    bodyType match {
-      case RawBodyType.StringBody(charset) => readAllBytes(serverRequest, maxBytes).map(bs => RawValue(new String(bs, charset)))
-      case RawBodyType.ByteArrayBody =>
-        readAllBytes(serverRequest, maxBytes).map(RawValue(_))
-      case RawBodyType.ByteBufferBody =>
-        readAllBytes(serverRequest, maxBytes).map(bs => RawValue(ByteBuffer.wrap(bs)))
-      case RawBodyType.InputStreamBody =>
-        // Possibly can be optimized to avoid loading all data eagerly into memory
-        readAllBytes(serverRequest, maxBytes).map(bs => RawValue(new ByteArrayInputStream(bs)))
-      case RawBodyType.InputStreamRangeBody =>
-        // Possibly can be optimized to avoid loading all data eagerly into memory
-        readAllBytes(serverRequest, maxBytes).map(bs => RawValue(InputStreamRange(() => new ByteArrayInputStream(bs))))
-      case RawBodyType.FileBody =>
-        for {
-          file <- createFile(serverRequest)
-          _ <- writeToFile(serverRequest, file, maxBytes)
-        } yield RawValue(FileRange(file), Seq(FileRange(file)))
-      case _: RawBodyType.MultipartBody => monad.error(new UnsupportedOperationException())
-    }
+  override def toRaw[RAW](
+      serverRequest: ServerRequest,
+      bodyType: RawBodyType[RAW],
+      maxBytes: Option[Long]
+  ): F[RawValue[RAW]] = bodyType match {
+    case RawBodyType.StringBody(charset) =>
+      readAllBytes(serverRequest, maxBytes).map(bs => RawValue(new String(bs, charset)))
+    case RawBodyType.ByteArrayBody =>
+      readAllBytes(serverRequest, maxBytes).map(RawValue(_))
+    case RawBodyType.ByteBufferBody =>
+      readAllBytes(serverRequest, maxBytes).map(bs => RawValue(ByteBuffer.wrap(bs)))
+    case RawBodyType.InputStreamBody =>
+      // Possibly can be optimized to avoid loading all data eagerly into memory
+      readAllBytes(serverRequest, maxBytes).map(bs => RawValue(new ByteArrayInputStream(bs)))
+    case RawBodyType.InputStreamRangeBody =>
+      // Possibly can be optimized to avoid loading all data eagerly into memory
+      readAllBytes(serverRequest, maxBytes).map(bs => RawValue(InputStreamRange(() => new ByteArrayInputStream(bs))))
+    case RawBodyType.FileBody =>
+      for {
+        file <- createFile(serverRequest)
+        _ <- writeToFile(serverRequest, file, maxBytes)
+      } yield RawValue(FileRange(file), Seq(FileRange(file)))
+    case _: RawBodyType.MultipartBody => monad.error(new UnsupportedOperationException)
   }
 
   private def readAllBytes(serverRequest: ServerRequest, maxBytes: Option[Long]): F[Array[Byte]] =
@@ -81,6 +84,7 @@ private[netty] trait NettyRequestBody[F[_], S <: Streams[S]] extends RequestBody
       case req: StreamedHttpRequest =>
         val contentLength = Option(req.headers().get(HeaderNames.ContentLength)).map(_.toInt)
         publisherToBytes(req, contentLength, maxBytes)
-      case other => monad.error(new UnsupportedOperationException(s"Unexpected Netty request of type ${other.getClass().getName()}"))
+      case other =>
+        monad.error(new UnsupportedOperationException(s"Unexpected Netty request of type ${other.getClass.getName}"))
     }
 }

diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/SimpleSubscriber.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/SimpleSubscriber.scala
index 5408274822..2ef846dc7c 100644
--- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/SimpleSubscriber.scala
+++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/SimpleSubscriber.scala
@@ -6,17 +6,22 @@ import org.reactivestreams.{Publisher, Subscription}
 import sttp.capabilities.StreamMaxLengthExceededException
 
 import java.util.concurrent.LinkedBlockingQueue
-import scala.concurrent.{Future, Promise}
+import scala.concurrent.duration.Duration
+import scala.concurrent.{Await, ExecutionContext, Future, Promise}
+import scala.util.Success
 
 private[netty] class SimpleSubscriber(contentLength: Option[Int]) extends PromisingSubscriber[Array[Byte], HttpContent] {
+  // These don't need to be volatile as Reactive Streams guarantees that onSubscribe/onNext/onError/onComplete are
+  // called serially (https://github.com/reactive-streams/reactive-streams-jvm?tab=readme-ov-file#1-publisher-code - rule 3)
+  // The only other methods are `future` and `resultBlocking` which are protected against any memory visibility issues
+  // by the result Promise.
   private var subscription: Subscription = _
+  private var buffers = Vector[ByteBuf]()
+  private var totalLength = 0
+  private val resultPromise = Promise[Array[Byte]]()
-  @volatile private var totalLength = 0
-  private val resultBlockingQueue = new LinkedBlockingQueue[Either[Throwable, Array[Byte]]](1)
-  @volatile private var buffers = Vector[ByteBuf]()
 
   override def future: Future[Array[Byte]] = resultPromise.future
-  def resultBlocking(): Either[Throwable, Array[Byte]] = resultBlockingQueue.take()
 
   override def onSubscribe(s: Subscription): Unit = {
     subscription = s
@@ -25,16 +30,16 @@ private[netty] class SimpleSubscriber(contentLength: Option[Int]) extends Promis
 
   override def onNext(content: HttpContent): Unit = {
     val byteBuf = content.content()
-    // If expected content length is known, and we receive exactly this amount of bytes, we assume there's only one chunk and
-    // we can immediately return it without going through the buffer list.
-    if (contentLength.contains(byteBuf.readableBytes())) {
+    // If expected content length is known, we haven't received any data yet, and we receive exactly this amount of bytes,
+    // we assume there's only one chunk and we can immediately return it without going through the buffer list.
+    if (buffers.isEmpty && contentLength.contains(byteBuf.readableBytes())) {
       val finalArray = ByteBufUtil.getBytes(byteBuf)
       byteBuf.release()
-      if (!resultBlockingQueue.offer(Right(finalArray))) {
-        // Queue full, which is unexpected. The previous chunk was supposed the be the only one. A malformed request perhaps?
+      if (!resultPromise.trySuccess(finalArray)) {
+        // Result is set, which is unexpected. The previous chunk was supposed to be the only one.
+        // A malformed request perhaps?
         subscription.cancel()
       } else {
-        resultPromise.success(finalArray)
         subscription.request(1)
       }
     } else {
@@ -49,12 +54,11 @@
       val _ = buf.release()
     }
     buffers = Vector.empty
-    resultBlockingQueue.offer(Left(t))
     resultPromise.failure(t)
   }
 
   override def onComplete(): Unit = {
-    if (!buffers.isEmpty) {
+    if (buffers.nonEmpty) {
       val mergedArray = new Array[Byte](totalLength)
       var currentIndex = 0
       buffers.foreach { buf =>
@@ -64,11 +68,7 @@
         val _ = buf.release()
       }
       buffers = Vector.empty
-      if (!resultBlockingQueue.offer(Right(mergedArray))) {
-        // Result queue full, which is unexpected.
-        resultPromise.failure(new IllegalStateException("Calling onComplete after result was already returned"))
-      } else
-        resultPromise.success(mergedArray)
+      resultPromise.success(mergedArray)
     } else {
       () // result already sent in onNext
     }
@@ -80,37 +80,15 @@ object SimpleSubscriber {
 
   def processAll(publisher: Publisher[HttpContent], contentLength: Option[Int], maxBytes: Option[Long]): Future[Array[Byte]] =
     maxBytes match {
-      case Some(max) if (contentLength.exists(_ > max)) => Future.failed(StreamMaxLengthExceededException(max))
-      case Some(max) => {
-        val subscriber = new SimpleSubscriber(contentLength)
-        publisher.subscribe(new LimitedLengthSubscriber(max, subscriber))
-        subscriber.future
-      }
-      case None => {
+      case Some(max) if contentLength.exists(_ > max) =>
+        Future.failed(StreamMaxLengthExceededException(max))
+      case _ =>
         val subscriber = new SimpleSubscriber(contentLength)
-        publisher.subscribe(subscriber)
+        val maybeLimitedSubscriber = maxBytes.fold(subscriber)(new LimitedLengthSubscriber(_, subscriber))
+        publisher.subscribe(maybeLimitedSubscriber)
         subscriber.future
-      }
     }
 
   def processAllBlocking(publisher: Publisher[HttpContent], contentLength: Option[Int], maxBytes: Option[Long]): Array[Byte] =
-    maxBytes match {
-      case Some(max) if (contentLength.exists(_ > max)) => throw new StreamMaxLengthExceededException(max)
-      case Some(max) => {
-        val subscriber = new SimpleSubscriber(contentLength)
-        publisher.subscribe(new LimitedLengthSubscriber(max, subscriber))
-        subscriber.resultBlocking() match {
-          case Right(result) => result
-          case Left(e)       => throw e
-        }
-      }
-      case None => {
-        val subscriber = new SimpleSubscriber(contentLength)
-        publisher.subscribe(subscriber)
-        subscriber.resultBlocking() match {
-          case Right(result) => result
-          case Left(e)       => throw e
-        }
-      }
-    }
+    Await.result(processAll(publisher, contentLength, maxBytes), Duration.Inf)
 }
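The core simplification in this patch is that a single Promise now backs both the asynchronous (`future`) and the blocking (`processAllBlocking`) result paths, replacing the separate LinkedBlockingQueue. A minimal standalone sketch of that pattern (ResultCollector and its members are illustrative names, not part of the patch):

    import scala.concurrent.duration.Duration
    import scala.concurrent.{Await, Future, Promise}

    // One Promise serves both consumption styles; trySuccess/tryFailure make a
    // duplicate completion observable (they return false) instead of throwing.
    final class ResultCollector[T] {
      private val promise = Promise[T]()

      def complete(value: T): Boolean = promise.trySuccess(value)
      def fail(error: Throwable): Boolean = promise.tryFailure(error)

      // Asynchronous access, as in SimpleSubscriber.future.
      def future: Future[T] = promise.future

      // Blocking access, as in processAllBlocking: Await.result rethrows a
      // stored failure, which the removed Either-based queue encoded by hand.
      def resultBlocking(): T = Await.result(promise.future, Duration.Inf)
    }

Await.result(..., Duration.Inf) is only appropriate where the calling thread is expected to block anyway, e.g. a non-effectful backend such as netty-loom.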
From 9f2883e36bd5b9cd55964ae78e87947ea8817928 Mon Sep 17 00:00:00 2001
From: Roman Janusz
Date: Mon, 25 Mar 2024 09:15:37 +0100
Subject: [PATCH 02/16] FileWriterSubscriber simplified

---
 .../reactivestreams/FileWriterSubscriber.scala | 16 ++++------------
 1 file changed, 4 insertions(+), 12 deletions(-)

diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/FileWriterSubscriber.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/FileWriterSubscriber.scala
index 28dfbb6d87..7047ca26ef 100644
--- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/FileWriterSubscriber.scala
+++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/FileWriterSubscriber.scala
@@ -5,8 +5,9 @@ import org.reactivestreams.{Publisher, Subscription}
 
 import java.nio.channels.AsynchronousFileChannel
 import java.nio.file.{Path, StandardOpenOption}
-import scala.concurrent.{Future, Promise}
 import java.util.concurrent.LinkedBlockingQueue
+import scala.concurrent.duration.Duration
+import scala.concurrent.{Await, Future, Promise}
 
 /** A Reactive Streams subscriber which receives chunks of bytes and writes them to a file.
   */
@@ -22,11 +23,7 @@ class FileWriterSubscriber(path: Path) extends PromisingSubscriber[Unit, HttpCon
   /** Used to signal completion, so that external code can represent writing to a file as Future[Unit] */
   private val resultPromise = Promise[Unit]()
 
-  /** An alternative way to signal completion, so that non-effectful servers can await on the response (like netty-loom) */
-  private val resultBlockingQueue = new LinkedBlockingQueue[Either[Throwable, Unit]]()
-
   override def future: Future[Unit] = resultPromise.future
-  private def waitForResultBlocking(): Either[Throwable, Unit] = resultBlockingQueue.take()
 
   override def onSubscribe(s: Subscription): Unit = {
     this.subscription = s
@@ -58,13 +55,11 @@ class FileWriterSubscriber(path: Path) extends PromisingSubscriber[Unit, HttpCon
 
   override def onError(t: Throwable): Unit = {
     fileChannel.close()
-    resultBlockingQueue.add(Left(t))
     resultPromise.failure(t)
   }
 
   override def onComplete(): Unit = {
     fileChannel.close()
-    val _ = resultBlockingQueue.add(Right(()))
     resultPromise.success(())
   }
 }
@@ -76,9 +71,6 @@ object FileWriterSubscriber {
     subscriber.future
   }
 
-  def processAllBlocking(publisher: Publisher[HttpContent], path: Path, maxBytes: Option[Long]): Unit = {
-    val subscriber = new FileWriterSubscriber(path)
-    publisher.subscribe(maxBytes.map(new LimitedLengthSubscriber(_, subscriber)).getOrElse(subscriber))
-    subscriber.waitForResultBlocking().left.foreach(e => throw e)
-  }
+  def processAllBlocking(publisher: Publisher[HttpContent], path: Path, maxBytes: Option[Long]): Unit =
+    Await.result(processAll(publisher, path, maxBytes), Duration.Inf)
 }

From 79ecbe7b968f997a50aada57d044731be27fca5b Mon Sep 17 00:00:00 2001
From: Roman Janusz
Date: Mon, 25 Mar 2024 09:40:10 +0100
Subject: [PATCH 03/16] contentLength in NettyRequestBody changed to Long

---
 .../netty/cats/internal/NettyCatsRequestBody.scala   |  2 +-
 .../tapir/server/netty/loom/NettyIdRequestBody.scala |  2 +-
 .../netty/internal/NettyFutureRequestBody.scala      |  2 +-
 .../server/netty/internal/NettyRequestBody.scala     |  4 ++--
 .../internal/reactivestreams/SimpleSubscriber.scala  | 11 ++++-------
 .../netty/zio/internal/NettyZioRequestBody.scala     |  2 +-
 6 files changed, 10 insertions(+), 13 deletions(-)

diff --git a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/NettyCatsRequestBody.scala b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/NettyCatsRequestBody.scala
index e81f66fd00..4dfe6c22fb 100644
--- a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/NettyCatsRequestBody.scala
+++ b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/NettyCatsRequestBody.scala
@@ -20,7 +20,7 @@ private[cats] class NettyCatsRequestBody[F[_]: Async](
 
   override implicit val monad: MonadError[F] = new CatsMonadError()
 
-  override def publisherToBytes(publisher: Publisher[HttpContent], contentLength: Option[Int], maxBytes: Option[Long]): F[Array[Byte]] =
+  override def publisherToBytes(publisher: Publisher[HttpContent], contentLength: Option[Long], maxBytes: Option[Long]): F[Array[Byte]] =
     streamCompatible.fromPublisher(publisher, maxBytes).compile.to(Chunk).map(_.toArray[Byte])
 
   override def writeToFile(serverRequest: ServerRequest, file: TapirFile, maxBytes: Option[Long]): F[Unit] =

diff --git a/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/NettyIdRequestBody.scala b/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/NettyIdRequestBody.scala
index a14dccae8c..7a9373e6c3 100644
--- a/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/NettyIdRequestBody.scala
+++ b/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/NettyIdRequestBody.scala
@@ -16,7 +16,7 @@ private[netty] class NettyIdRequestBody(val createFile: ServerRequest => TapirFi
   override implicit val monad: MonadError[Id] = idMonad
   override val streams: capabilities.Streams[NoStreams] = NoStreams
 
-  override def publisherToBytes(publisher: Publisher[HttpContent], contentLength: Option[Int], maxBytes: Option[Long]): Array[Byte] =
+  override def publisherToBytes(publisher: Publisher[HttpContent], contentLength: Option[Long], maxBytes: Option[Long]): Array[Byte] =
     SimpleSubscriber.processAllBlocking(publisher, contentLength, maxBytes)
 
   override def writeToFile(serverRequest: ServerRequest, file: TapirFile, maxBytes: Option[Long]): Unit =

diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyFutureRequestBody.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyFutureRequestBody.scala
index 41b4e023ca..edaef8e524 100644
--- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyFutureRequestBody.scala
+++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyFutureRequestBody.scala
@@ -18,7 +18,7 @@ private[netty] class NettyFutureRequestBody(val createFile: ServerRequest => Fut
   override val streams: capabilities.Streams[NoStreams] = NoStreams
   override implicit val monad: MonadError[Future] = new FutureMonad()
 
-  override def publisherToBytes(publisher: Publisher[HttpContent], contentLength: Option[Int], maxBytes: Option[Long]): Future[Array[Byte]] =
+  override def publisherToBytes(publisher: Publisher[HttpContent], contentLength: Option[Long], maxBytes: Option[Long]): Future[Array[Byte]] =
     SimpleSubscriber.processAll(publisher, contentLength, maxBytes)
 
   override def writeToFile(serverRequest: ServerRequest, file: TapirFile, maxBytes: Option[Long]): Future[Unit] =

diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyRequestBody.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyRequestBody.scala
index 87d3a8bd8d..78f01c0ea0 100644
--- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyRequestBody.scala
+++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyRequestBody.scala
@@ -37,7 +37,7 @@ private[netty] trait NettyRequestBody[F[_], S <: Streams[S]] extends RequestBody
    * @return
    *   An effect which finishes with a single array of all collected bytes.
    */
-  def publisherToBytes(publisher: Publisher[HttpContent], contentLength: Option[Int], maxBytes: Option[Long]): F[Array[Byte]]
+  def publisherToBytes(publisher: Publisher[HttpContent], contentLength: Option[Long], maxBytes: Option[Long]): F[Array[Byte]]
 
   /** Backend-specific way to process all elements emitted by a Publisher[HttpContent] and write their bytes into a file.
    *
@@ -82,7 +82,7 @@ private[netty] trait NettyRequestBody[F[_], S <: Streams[S]] extends RequestBody
       case r: FullHttpRequest if r.content() == Unpooled.EMPTY_BUFFER => // Empty request
         monad.unit(Array.empty[Byte])
       case req: StreamedHttpRequest =>
-        val contentLength = Option(req.headers().get(HeaderNames.ContentLength)).map(_.toInt)
+        val contentLength = Option(req.headers().get(HeaderNames.ContentLength)).map(_.toLong)
         publisherToBytes(req, contentLength, maxBytes)
       case other =>
         monad.error(new UnsupportedOperationException(s"Unexpected Netty request of type ${other.getClass.getName}"))

diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/SimpleSubscriber.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/SimpleSubscriber.scala
index 2ef846dc7c..ef220e41ef 100644
--- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/SimpleSubscriber.scala
+++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/SimpleSubscriber.scala
@@ -8,13 +8,10 @@ import sttp.capabilities.StreamMaxLengthExceededException
 import java.util.concurrent.LinkedBlockingQueue
 import scala.concurrent.duration.Duration
 import scala.concurrent.{Await, ExecutionContext, Future, Promise}
-import scala.util.Success
 
-private[netty] class SimpleSubscriber(contentLength: Option[Int]) extends PromisingSubscriber[Array[Byte], HttpContent] {
+private[netty] class SimpleSubscriber(contentLength: Option[Long]) extends PromisingSubscriber[Array[Byte], HttpContent] {
   // These don't need to be volatile as Reactive Streams guarantees that onSubscribe/onNext/onError/onComplete are
   // called serially (https://github.com/reactive-streams/reactive-streams-jvm?tab=readme-ov-file#1-publisher-code - rule 3)
-  // The only other methods are `future` and `resultBlocking` which are protected against any memory visibility issues
-  // by the result Promise.
   private var subscription: Subscription = _
   private var buffers = Vector[ByteBuf]()
   private var totalLength = 0
@@ -75,17 +75,17 @@ private[netty] class SimpleSubscriber(contentLength: Option[Int]) extends Promis
 
 object SimpleSubscriber {
 
-  def processAll(publisher: Publisher[HttpContent], contentLength: Option[Int], maxBytes: Option[Long]): Future[Array[Byte]] =
+  def processAll(publisher: Publisher[HttpContent], contentLength: Option[Long], maxBytes: Option[Long]): Future[Array[Byte]] =
     maxBytes match {
       case Some(max) if contentLength.exists(_ > max) =>
         Future.failed(StreamMaxLengthExceededException(max))
       case _ =>
         val subscriber = new SimpleSubscriber(contentLength)
-        val maybeLimitedSubscriber = maxBytes.fold(subscriber)(new LimitedLengthSubscriber(_, subscriber))
+        val maybeLimitedSubscriber = maxBytes.map(new LimitedLengthSubscriber(_, subscriber)).getOrElse(subscriber)
         publisher.subscribe(maybeLimitedSubscriber)
         subscriber.future
     }
 
-  def processAllBlocking(publisher: Publisher[HttpContent], contentLength: Option[Int], maxBytes: Option[Long]): Array[Byte] =
+  def processAllBlocking(publisher: Publisher[HttpContent], contentLength: Option[Long], maxBytes: Option[Long]): Array[Byte] =
     Await.result(processAll(publisher, contentLength, maxBytes), Duration.Inf)
 }

diff --git a/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/internal/NettyZioRequestBody.scala b/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/internal/NettyZioRequestBody.scala
index d7fcedb1d1..3cb9b9ab21 100644
--- a/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/internal/NettyZioRequestBody.scala
+++ b/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/internal/NettyZioRequestBody.scala
@@ -21,7 +21,7 @@ private[zio] class NettyZioRequestBody[Env](
 
   override def publisherToBytes(
       publisher: Publisher[HttpContent],
-      contentLength: Option[Int],
+      contentLength: Option[Long],
       maxBytes: Option[Long]
   ): RIO[Env, Array[Byte]] =
     streamCompatible.fromPublisher(publisher, maxBytes).run(ZSink.collectAll[Byte]).map(_.toArray)
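Reading Content-Length as Long matters because HTTP permits bodies larger than Int.MaxValue bytes: parsing such a header with `_.toInt` would throw before the maxBytes pre-check even runs. A small sketch with an assumed 3 GiB header value:

    // Hypothetical header value, larger than Int.MaxValue (2147483647):
    val headerValue = "3221225472" // 3 GiB

    val asInt = headerValue.toIntOption   // None: does not fit into an Int
    val asLong = headerValue.toLongOption // Some(3221225472L)

    // The shape of the pre-check from processAll, with an assumed 1 GiB limit:
    val maxBytes: Option[Long] = Some(1024L * 1024 * 1024)
    val exceedsLimit = asLong.exists(len => maxBytes.exists(len > _))
    // exceedsLimit == true, so the body can be rejected before any chunk is read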
From 9e6ddea5cc959450fe94f0c1acf848a336d8aabe Mon Sep 17 00:00:00 2001
From: Roman Janusz
Date: Mon, 25 Mar 2024 10:24:56 +0100
Subject: [PATCH 04/16] accepting explicit RunAsync in InputStreamPublisher

---
 .../netty/loom/NettyIdServerInterpreter.scala |  9 ++-------
 .../netty/NettyFutureServerInterpreter.scala  |  8 ++------
 .../netty/internal/NettyToResponseBody.scala  |  6 ++++--
 .../server/netty/internal/RunAsync.scala      | 14 +++++++++++++
 .../InputStreamPublisher.scala                | 20 ++++++++++++-------
 5 files changed, 35 insertions(+), 22 deletions(-)

diff --git a/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/NettyIdServerInterpreter.scala b/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/NettyIdServerInterpreter.scala
index 973e8df1d2..059f772382 100644
--- a/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/NettyIdServerInterpreter.scala
+++ b/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/NettyIdServerInterpreter.scala
@@ -13,14 +13,9 @@ trait NettyIdServerInterpreter {
       ses,
       nettyServerOptions.interceptors,
       new NettyIdRequestBody(nettyServerOptions.createFile),
-      new NettyToResponseBody[Id],
+      new NettyToResponseBody[Id](RunAsync.Id),
       nettyServerOptions.deleteFile,
-      new RunAsync[Id] {
-        override def apply(f: => Id[Unit]): Unit = {
-          val _ = f
-          ()
-        }
-      }
+      RunAsync.Id
     )
   }
 }

diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyFutureServerInterpreter.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyFutureServerInterpreter.scala
index 6124e5c35f..59db8c796e 100644
--- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyFutureServerInterpreter.scala
+++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyFutureServerInterpreter.scala
@@ -22,9 +22,9 @@ trait NettyFutureServerInterpreter {
       ses,
       nettyServerOptions.interceptors,
       new NettyFutureRequestBody(nettyServerOptions.createFile),
-      new NettyToResponseBody[Future](),
+      new NettyToResponseBody[Future](RunAsync.Future),
      nettyServerOptions.deleteFile,
-      FutureRunAsync
+      RunAsync.Future
     )
   }
 }
@@ -35,8 +35,4 @@ object NettyFutureServerInterpreter {
       override def nettyServerOptions: NettyFutureServerOptions = serverOptions
     }
   }
-
-  private object FutureRunAsync extends RunAsync[Future] {
-    override def apply(f: => Future[Unit]): Unit = f
-  }
 }

diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyToResponseBody.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyToResponseBody.scala
index e0b2c0b35e..79f3685850 100644
--- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyToResponseBody.scala
+++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyToResponseBody.scala
@@ -23,7 +23,9 @@ import java.nio.charset.Charset
  * Publishers to integrate responses like InputStreamBody, InputStreamRangeBody or FileBody with Netty reactive extensions. Other kinds of
  * raw responses like directly available String, ByteArray or ByteBuffer can be returned without wrapping into a Publisher.
  */
-private[netty] class NettyToResponseBody[F[_]](implicit me: MonadError[F]) extends ToResponseBody[NettyResponse, NoStreams] {
+private[netty] class NettyToResponseBody[F[_]](runAsync: RunAsync[F])(implicit me: MonadError[F])
+    extends ToResponseBody[NettyResponse, NoStreams] {
+
   override val streams: capabilities.Streams[NoStreams] = NoStreams
 
   override def fromRawValue[R](v: R, headers: HasHeaders, format: CodecFormat, bodyType: RawBodyType[R]): NettyResponse = {
@@ -54,7 +56,7 @@ private[netty] class NettyToResponseBody[F[_]](implicit me: MonadError[F]) exten
   }
 
   private def wrap(streamRange: InputStreamRange): Publisher[HttpContent] = {
-    new InputStreamPublisher[F](streamRange, DefaultChunkSize)
+    new InputStreamPublisher[F](streamRange, DefaultChunkSize, runAsync)
   }
 
   private def wrap(fileRange: FileRange): Publisher[HttpContent] = {

diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/RunAsync.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/RunAsync.scala
index 691552fea2..acfbce28c0 100644
--- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/RunAsync.scala
+++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/RunAsync.scala
@@ -1,5 +1,19 @@
 package sttp.tapir.server.netty.internal
 
+import scala.concurrent.Future
+
 trait RunAsync[F[_]] {
   def apply(f: => F[Unit]): Unit
 }
+object RunAsync {
+  type Id[A] = A
+
+  final val Id: RunAsync[Id] = new RunAsync[Id] {
+    override def apply(f: => Id[Unit]): Unit = f
+  }
+
+  final val Future: RunAsync[Future] = new RunAsync[Future] {
+    override def apply(f: => Future[Unit]): Unit =
+      f: Unit
+  }
+}

diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/InputStreamPublisher.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/InputStreamPublisher.scala
index 7f16c0a108..5c1da6bea7 100644
--- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/InputStreamPublisher.scala
+++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/InputStreamPublisher.scala
@@ -3,15 +3,22 @@ package sttp.tapir.server.netty.internal.reactivestreams
 import io.netty.buffer.Unpooled
 import io.netty.handler.codec.http.{DefaultHttpContent, HttpContent}
 import org.reactivestreams.{Publisher, Subscriber, Subscription}
+import sttp.monad.MonadError
+import sttp.monad.syntax._
 import sttp.tapir.InputStreamRange
+import sttp.tapir.server.netty.internal.RunAsync
 
 import java.io.InputStream
 import java.util.concurrent.atomic.{AtomicBoolean, AtomicLong}
 import scala.util.Try
-import sttp.monad.MonadError
-import sttp.monad.syntax._
 
-class InputStreamPublisher[F[_]](range: InputStreamRange, chunkSize: Int)(implicit monad: MonadError[F]) extends Publisher[HttpContent] {
+class InputStreamPublisher[F[_]](
+    range: InputStreamRange,
+    chunkSize: Int,
+    runAsync: RunAsync[F]
+)(implicit
+    monad: MonadError[F]
+) extends Publisher[HttpContent] {
   override def subscribe(subscriber: Subscriber[_ >: HttpContent]): Unit = {
     if (subscriber == null) throw new NullPointerException("Subscriber cannot be null")
     val subscription = new InputStreamSubscription(subscriber, range, chunkSize)
@@ -46,7 +53,7 @@ class InputStreamPublisher[F[_]](range: InputStreamRange, chunkSize: Int)(implic
           case _ => chunkSize
         }
 
-        val _ = monad
+        runAsync(monad
           .blocking(
             stream.readNBytes(expectedBytes)
           )
@@ -69,11 +76,10 @@ class InputStreamPublisher[F[_]](range: InputStreamRange, chunkSize: Int)(implic
           }
         }
         .handleError {
-          case e => {
+          case e =>
             val _ = Try(stream.close())
             monad.unit(subscriber.onError(e))
-          }
-        }
+        })
     }
   }
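The RunAsync abstraction introduced here separates "start this effect and walk away" from the monad itself: the Id instance runs the effect eagerly on the calling thread, while the Future instance relies on the future already being scheduled. An effectful backend built on cats-effect could plausibly supply an instance via a Dispatcher; a hedged sketch (the actual cats/zio backend wiring is not shown in this patch):

    import cats.effect.IO
    import cats.effect.std.Dispatcher

    // Mirror of the trait from this patch, repeated so the sketch is self-contained.
    trait RunAsync[F[_]] {
      def apply(f: => F[Unit]): Unit
    }

    // Fire-and-forget for IO: the Dispatcher schedules the effect on its runtime
    // and returns immediately, matching RunAsync's Unit-returning contract.
    final class CatsRunAsync(dispatcher: Dispatcher[IO]) extends RunAsync[IO] {
      override def apply(f: => IO[Unit]): Unit = dispatcher.unsafeRunAndForget(f)
    }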
From 6bacd318c60c218287c08ae2d777a0ffa85f1f9b Mon Sep 17 00:00:00 2001
From: Roman Janusz
Date: Mon, 25 Mar 2024 10:28:58 +0100
Subject: [PATCH 05/16] removed accidentally committed change

---
 build.sbt | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/build.sbt b/build.sbt
index e9f8347bf6..68155a7ace 100644
--- a/build.sbt
+++ b/build.sbt
@@ -121,8 +121,7 @@ val commonJvmSettings: Seq[Def.Setting[_]] = commonSettings ++ Seq(
       case Some((2, _)) => Seq("-target:jvm-1.8") // some users are on java 8
       case _            => Seq.empty[String]
     }
-  },
-  run / fork := true
+  }
 )
 
 // run JS tests inside Gecko, due to jsdom not supporting fetch and to avoid having to install node

From 58e3b92ad6f0eb56f18b99e2625285238a9b2ff8 Mon Sep 17 00:00:00 2001
From: Roman Janusz
Date: Mon, 25 Mar 2024 10:30:12 +0100
Subject: [PATCH 06/16] fixed bad import

---
 .../sttp/tapir/server/netty/NettyFutureServerInterpreter.scala | 1 -
 1 file changed, 1 deletion(-)

diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyFutureServerInterpreter.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyFutureServerInterpreter.scala
index 59db8c796e..7800066b0d 100644
--- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyFutureServerInterpreter.scala
+++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyFutureServerInterpreter.scala
@@ -2,7 +2,6 @@ package sttp.tapir.server.netty
 
 import sttp.monad.FutureMonad
 import sttp.tapir.server.ServerEndpoint
-import sttp.tapir.server.netty.NettyFutureServerInterpreter.FutureRunAsync
 import sttp.tapir.server.netty.internal.{NettyFutureRequestBody, NettyServerInterpreter, NettyToResponseBody, RunAsync}
 
 import scala.concurrent.{ExecutionContext, Future}

From e11f2d4733a52ea697e88674627199ad682ec14a Mon Sep 17 00:00:00 2001
From: Roman Janusz
Date: Mon, 25 Mar 2024 11:05:37 +0100
Subject: [PATCH 07/16] added comment about Id in InputStreamPublisher

---
 .../netty/internal/reactivestreams/InputStreamPublisher.scala | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/InputStreamPublisher.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/InputStreamPublisher.scala
index 5c1da6bea7..a96e297ee4 100644
--- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/InputStreamPublisher.scala
+++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/InputStreamPublisher.scala
@@ -53,6 +53,8 @@ class InputStreamPublisher[F[_]](
           case _ => chunkSize
         }
 
+        // Note: the effect F may be Id, in which case everything here will be synchronous and blocking
+        // (which technically is against the reactive streams spec).
         runAsync(monad
           .blocking(
             stream.readNBytes(expectedBytes)

From 221013c549e7399746b40e849b5d239a18926139 Mon Sep 17 00:00:00 2001
From: Roman Janusz
Date: Mon, 25 Mar 2024 14:40:43 +0100
Subject: [PATCH 08/16] SubscriberInputStream implementation

---
 .../netty/internal/NettyRequestBody.scala |  37 +++--
 .../SubscriberInputStream.scala           | 136 ++++++++++++++++++
 .../SubscriberInputStreamTest.scala       |  95 ++++++++++++
 3 files changed, 254 insertions(+), 14 deletions(-)
 create mode 100644 server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/SubscriberInputStream.scala
 create mode 100644 server/netty-server/src/test/scala/sttp/tapir/server/netty/internal/reactivestreams/SubscriberInputStreamTest.scala

diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyRequestBody.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyRequestBody.scala
index 78f01c0ea0..88712cce15 100644
--- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyRequestBody.scala
+++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyRequestBody.scala
@@ -4,19 +4,17 @@ import io.netty.buffer.Unpooled
 import io.netty.handler.codec.http.{FullHttpRequest, HttpContent}
 import org.playframework.netty.http.StreamedHttpRequest
 import org.reactivestreams.Publisher
+import sttp.capabilities.Streams
+import sttp.model.HeaderNames
 import sttp.monad.MonadError
 import sttp.monad.syntax._
+import sttp.tapir.{FileRange, InputStreamRange, RawBodyType, TapirFile}
 import sttp.tapir.model.ServerRequest
-import sttp.tapir.server.interpreter.RequestBody
-import sttp.tapir.RawBodyType
-import sttp.tapir.TapirFile
-import sttp.tapir.server.interpreter.RawValue
-import sttp.tapir.FileRange
-import sttp.tapir.InputStreamRange
+import sttp.tapir.server.interpreter.{RawValue, RequestBody}
+import sttp.tapir.server.netty.internal.reactivestreams.SubscriberInputStream
 
-import java.io.ByteArrayInputStream
+import java.io.{ByteArrayInputStream, InputStream}
 import java.nio.ByteBuffer
-import sttp.capabilities.Streams
-import sttp.model.HeaderNames
 
 /** Common logic for processing request body in all Netty backends. It requires particular backends to implement a few operations.
   */
@@ -64,17 +62,16 @@ private[netty] trait NettyRequestBody[F[_], S <: Streams[S]] extends RequestBody
     case RawBodyType.ByteBufferBody =>
       readAllBytes(serverRequest, maxBytes).map(bs => RawValue(ByteBuffer.wrap(bs)))
     case RawBodyType.InputStreamBody =>
-      // Possibly can be optimized to avoid loading all data eagerly into memory
-      readAllBytes(serverRequest, maxBytes).map(bs => RawValue(new ByteArrayInputStream(bs)))
+      monad.eval(RawValue(readAsStream(serverRequest, maxBytes)))
     case RawBodyType.InputStreamRangeBody =>
-      // Possibly can be optimized to avoid loading all data eagerly into memory
-      readAllBytes(serverRequest, maxBytes).map(bs => RawValue(InputStreamRange(() => new ByteArrayInputStream(bs))))
+      monad.unit(RawValue(InputStreamRange(() => readAsStream(serverRequest, maxBytes))))
    case RawBodyType.FileBody =>
       for {
         file <- createFile(serverRequest)
         _ <- writeToFile(serverRequest, file, maxBytes)
       } yield RawValue(FileRange(file), Seq(FileRange(file)))
-    case _: RawBodyType.MultipartBody => monad.error(new UnsupportedOperationException)
+    case _: RawBodyType.MultipartBody =>
+      monad.error(new UnsupportedOperationException)
   }
 
   private def readAllBytes(serverRequest: ServerRequest, maxBytes: Option[Long]): F[Array[Byte]] =
@@ -87,4 +84,16 @@ private[netty] trait NettyRequestBody[F[_], S <: Streams[S]] extends RequestBody
       case other =>
         monad.error(new UnsupportedOperationException(s"Unexpected Netty request of type ${other.getClass.getName}"))
     }
+
+  private def readAsStream(serverRequest: ServerRequest, maxBytes: Option[Long]): InputStream = {
+    serverRequest.underlying match {
+      case r: FullHttpRequest if r.content() == Unpooled.EMPTY_BUFFER => // Empty request
+        InputStream.nullInputStream()
+      case req: StreamedHttpRequest =>
+        val contentLength = Option(req.headers().get(HeaderNames.ContentLength)).map(_.toLong)
+        SubscriberInputStream.processAsStream(req, contentLength, maxBytes)
+      case other =>
+        throw new UnsupportedOperationException(s"Unexpected Netty request of type ${other.getClass.getName}")
+    }
+  }
 }

diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/SubscriberInputStream.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/SubscriberInputStream.scala
new file mode 100644
index 0000000000..62a99c9539
--- /dev/null
+++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/SubscriberInputStream.scala
@@ -0,0 +1,136 @@
+package sttp.tapir.server.netty.internal.reactivestreams
+
+import io.netty.buffer.ByteBuf
+import io.netty.handler.codec.http.HttpContent
+import org.reactivestreams.{Publisher, Subscriber, Subscription}
+import sttp.capabilities.StreamMaxLengthExceededException
+
+import java.io.{IOException, InputStream}
+import java.util.concurrent.LinkedBlockingQueue
+import scala.annotation.tailrec
+import scala.concurrent.Promise
+
+/** A blocking input stream that reads from a reactive streams publisher of [[HttpContent]].
+  * @param maxBufferedChunks
+  *   maximum number of unread chunks that can be buffered before blocking the publisher
+  */
+private[netty] class SubscriberInputStream(maxBufferedChunks: Int = 1)
+    extends InputStream with Subscriber[HttpContent] {
+
+  require(maxBufferedChunks > 0)
+
+  import SubscriberInputStream._
+
+  // volatile because used in both InputStream & Subscriber methods
+  @volatile private[this] var closed = false
+  // Calls on the subscription must be synchronized in order to satisfy the Reactive Streams spec
+  // (https://github.com/reactive-streams/reactive-streams-jvm?tab=readme-ov-file#2-subscriber-code - rule 7)
+  // because they are called both from InputStream & Subscriber methods.
+  private[this] var subscription: Subscription = _
+  private[this] var currentItem: Item = _
+  // the queue serves as a buffer to allow for possible parallelism between the subscriber and the publisher
+  private val queue = new LinkedBlockingQueue[Item](maxBufferedChunks + 1) // +1 to have a spot for End/Error
+
+  private def readItem(blocking: Boolean): Item = {
+    if (currentItem eq null) {
+      currentItem = if (blocking) queue.take() else queue.poll()
+      currentItem match {
+        case _: Chunk => synchronized(subscription.request(1))
+        case _        =>
+      }
+    }
+    currentItem
+  }
+
+  override def available(): Int =
+    readItem(blocking = false) match {
+      case Chunk(data) => data.readableBytes()
+      case _           => 0
+    }
+
+  override def read(): Int = {
+    val buffer = new Array[Byte](1)
+    if (read(buffer) == -1) -1 else buffer(0)
+  }
+
+  override def read(b: Array[Byte], off: Int, len: Int): Int =
+    if (closed) throw new IOException("Stream closed")
+    else if (len == 0) 0
+    else
+      readItem(blocking = true) match {
+        case Chunk(data) =>
+          val toRead = Math.min(len, data.readableBytes())
+          data.readBytes(b, off, toRead)
+          if (data.readableBytes() == 0) {
+            data.release()
+            currentItem = null
+          }
+          toRead
+        case Error(cause) => throw cause
+        case End          => -1
+      }
+
+  override def close(): Unit = if (!closed) {
+    synchronized(subscription.cancel())
+    closed = true
+    clearQueue()
+  }
+
+  @tailrec private def clearQueue(): Unit =
+    queue.poll() match {
+      case Chunk(data) =>
+        data.release()
+        clearQueue()
+      case _ =>
+    }
+
+  override def onSubscribe(s: Subscription): Unit = synchronized {
+    if (s eq null) {
+      throw new NullPointerException("Subscription must not be null")
+    }
+    subscription = s
+    subscription.request(maxBufferedChunks)
+  }
+
+  override def onNext(chunk: HttpContent): Unit = {
+    if (!queue.offer(Chunk(chunk.content()))) {
+      // This should be impossible according to the Reactive Streams spec,
+      // if it happens then it's a bug in the implementation of the subscriber or publisher
+      chunk.release()
+      synchronized(subscription.cancel())
+    } else if (closed) {
+      clearQueue()
+    }
+  }
+
+  override def onError(t: Throwable): Unit =
+    if (!closed) {
+      queue.offer(Error(t))
+    }
+
+  override def onComplete(): Unit =
+    if (!closed) {
+      queue.offer(End)
+    }
+}
+private[netty] object SubscriberInputStream {
+  private sealed abstract class Item
+  private case class Chunk(data: ByteBuf) extends Item
+  private case class Error(cause: Throwable) extends Item
+  private object End extends Item
+
+  def processAsStream(
+      publisher: Publisher[HttpContent],
+      contentLength: Option[Long],
+      maxBytes: Option[Long],
+      maxBufferedChunks: Int = 1
+  ): InputStream = maxBytes match {
+    case Some(max) if contentLength.exists(_ > max) =>
+      throw StreamMaxLengthExceededException(max)
+    case _ =>
+      val subscriber = new SubscriberInputStream(maxBufferedChunks)
+      val maybeLimitedSubscriber = maxBytes.map(new LimitedLengthSubscriber(_, subscriber)).getOrElse(subscriber)
+      publisher.subscribe(maybeLimitedSubscriber)
+      subscriber
+  }
+}

diff --git a/server/netty-server/src/test/scala/sttp/tapir/server/netty/internal/reactivestreams/SubscriberInputStreamTest.scala b/server/netty-server/src/test/scala/sttp/tapir/server/netty/internal/reactivestreams/SubscriberInputStreamTest.scala
new file mode 100644
index 0000000000..f1f4364681
--- /dev/null
+++ b/server/netty-server/src/test/scala/sttp/tapir/server/netty/internal/reactivestreams/SubscriberInputStreamTest.scala
@@ -0,0 +1,95 @@
+package sttp.tapir.server.netty.internal.reactivestreams
+
+import cats.effect.IO
+import cats.effect.kernel.Resource
+import cats.effect.unsafe.IORuntime
+import fs2.Stream
+import fs2.interop.reactivestreams._
+import io.netty.buffer.Unpooled
+import io.netty.handler.codec.http.DefaultHttpContent
+import org.scalatest.freespec.AnyFreeSpec
+import org.scalatest.matchers.should.Matchers
+import org.scalactic.source.Position
+
+import scala.util.Random
+
+class SubscriberInputStreamTest extends AnyFreeSpec with Matchers {
+  private implicit def runtime: IORuntime = IORuntime.global
+
+  private def testReading(
+      size: Int,
+      chunkLimit: Int = 1024,
+      maxBufferedChunks: Int = 1
+  )(implicit pos: Position): Unit = {
+    val bytes = new Array[Byte](size)
+    Random.nextBytes(bytes)
+
+    val publisherResource = Stream
+      .emits(bytes)
+      .chunkLimit(chunkLimit)
+      .map(ch => new DefaultHttpContent(Unpooled.wrappedBuffer(ch.toByteBuffer)))
+      .covary[IO]
+      .toUnicastPublisher
+
+    val readBytes = publisherResource.use { publisher =>
+      IO {
+        val subscriberInputStream = new SubscriberInputStream(maxBufferedChunks)
+        publisher.subscribe(subscriberInputStream)
+        subscriberInputStream.readAllBytes()
+      }
+    }
+
+    readBytes.unsafeRunSync() shouldBe bytes
+    ()
+  }
+
+  "empty stream" in {
+    testReading(0)
+  }
+
+  "single chunk stream" in {
+    testReading(10)
+  }
+
+  "multiple chunks" in {
+    testReading(100, 10)
+  }
+
+  "multiple chunks with larger buffer" in {
+    testReading(100, 10, maxBufferedChunks = 5)
+  }
+
+  "multiple chunks with smaller last chunk" in {
+    testReading(105, 10)
+  }
+
+  "closing the stream should cancel the subscription" in {
+    var canceled = false
+
+    val publisherResource =
+      Stream
+        .emits(Array.fill(1024)(0.toByte))
+        .chunkLimit(100)
+        .map(ch => new DefaultHttpContent(Unpooled.wrappedBuffer(ch.toByteBuffer)))
+        .covary[IO]
+        .onFinalizeCase {
+          case Resource.ExitCase.Canceled => IO { canceled = true }
+          case _                          => IO.unit
+        }
+        .toUnicastPublisher
+
+    publisherResource
+      .use(publisher =>
+        IO {
+          val stream = new SubscriberInputStream()
+          publisher.subscribe(stream)
+
+          stream.readNBytes(120).length shouldBe 120
+          stream.close()
+        }
+      )
+      .unsafeRunSync()
+
+    canceled shouldBe true
+  }
+}
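SubscriberInputStream's flow control comes down to a bounded hand-off: a queue holding at most maxBufferedChunks data items (plus one slot reserved for the End/Error marker), where the reader requests a new chunk only after dequeuing one. The same mechanism can be shown with plain threads and a LinkedBlockingQueue; a self-contained sketch with illustrative names and sizes:

    import java.util.concurrent.LinkedBlockingQueue

    object BoundedHandoffDemo extends App {
      sealed trait Item
      final case class Chunk(data: Array[Byte]) extends Item
      case object End extends Item

      val maxBufferedChunks = 1
      // +1 slot for the terminal marker, as in SubscriberInputStream
      val queue = new LinkedBlockingQueue[Item](maxBufferedChunks + 1)

      val producer = new Thread(() => {
        // put() blocks while the queue is full: the "publisher" cannot
        // outrun the consumer by more than the buffer size
        (1 to 3).foreach(i => queue.put(Chunk(Array.fill(4)(i.toByte))))
        queue.put(End)
      })
      producer.start()

      // The consumer "requests" the next chunk simply by taking one off
      // the queue, which frees a slot for the producer
      Iterator
        .continually(queue.take())
        .takeWhile(_ != End)
        .foreach { case Chunk(data) => println(data.toVector); case End => () }
      producer.join()
    }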
From 1884bd10c67ce8d6c6be0f78cef1ea9e0a663219 Mon Sep 17 00:00:00 2001
From: Roman Janusz
Date: Mon, 25 Mar 2024 18:11:09 +0100
Subject: [PATCH 09/16] removed bad import

---
 .../sttp/tapir/server/netty/NettyFutureServerInterpreter.scala | 1 -
 1 file changed, 1 deletion(-)

diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyFutureServerInterpreter.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyFutureServerInterpreter.scala
index 59db8c796e..7800066b0d 100644
--- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyFutureServerInterpreter.scala
+++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyFutureServerInterpreter.scala
@@ -2,7 +2,6 @@ package sttp.tapir.server.netty
 
 import sttp.monad.FutureMonad
 import sttp.tapir.server.ServerEndpoint
-import sttp.tapir.server.netty.NettyFutureServerInterpreter.FutureRunAsync
 import sttp.tapir.server.netty.internal.{NettyFutureRequestBody, NettyServerInterpreter, NettyToResponseBody, RunAsync}
 
 import scala.concurrent.{ExecutionContext, Future}

From cc4a6322952e3be7111983933c81b406bc98e8eb Mon Sep 17 00:00:00 2001
From: Roman Janusz
Date: Tue, 26 Mar 2024 11:44:40 +0100
Subject: [PATCH 10/16] proper handling of blocking code in input stream tests

---
 .../zio/src/main/scala/sttp/tapir/ztapir/RIOMonadError.scala | 1 +
 .../scala/sttp/tapir/server/tests/ServerBasicTests.scala     | 4 ++--
 .../src/main/scala/sttp/tapir/server/tests/package.scala     | 5 +++--
 3 files changed, 6 insertions(+), 4 deletions(-)

diff --git a/integrations/zio/src/main/scala/sttp/tapir/ztapir/RIOMonadError.scala b/integrations/zio/src/main/scala/sttp/tapir/ztapir/RIOMonadError.scala
index 00c4432f95..d517b5ab7b 100644
--- a/integrations/zio/src/main/scala/sttp/tapir/ztapir/RIOMonadError.scala
+++ b/integrations/zio/src/main/scala/sttp/tapir/ztapir/RIOMonadError.scala
@@ -13,4 +13,5 @@ class RIOMonadError[R] extends MonadError[RIO[R, *]] {
   override def suspend[T](t: => RIO[R, T]): RIO[R, T] = ZIO.suspend(t)
   override def flatten[T](ffa: RIO[R, RIO[R, T]]): RIO[R, T] = ffa.flatten
   override def ensure[T](f: RIO[R, T], e: => RIO[R, Unit]): RIO[R, T] = f.ensuring(e.ignore)
+  override def blocking[T](t: => T): RIO[R, T] = ZIO.attemptBlocking(t)
 }

diff --git a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerBasicTests.scala b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerBasicTests.scala
index 000abfedb1..672481a278 100644
--- a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerBasicTests.scala
+++ b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerBasicTests.scala
@@ -735,7 +735,7 @@ class ServerBasicTests[F[_], OPTIONS, ROUTE](
   def inputStreamTests(): List[Test] = List(
     testServer(in_input_stream_out_input_stream)((is: InputStream) =>
-      pureResult((new ByteArrayInputStream(inputStreamToByteArray(is)): InputStream).asRight[Unit])
+      blockingResult((new ByteArrayInputStream(inputStreamToByteArray(is)): InputStream).asRight[Unit])
     ) { (backend, baseUri) => basicRequest.post(uri"$baseUri/api/echo").body("mango").send(backend).map(_.body shouldBe Right("mango")) },
     testServer(in_string_out_stream_with_header)(_ => pureResult(Right((new ByteArrayInputStream(Array.fill[Byte](128)(0)), Some(128))))) {
       (backend, baseUri) =>
@@ -795,7 +795,7 @@ class ServerBasicTests[F[_], OPTIONS, ROUTE](
       "checks payload limit and returns OK on content length below or equal max (request)"
     )(i => {
       // Forcing server logic to drain the InputStream
-      suspendResult(i.readAllBytes()).map(_ => new ByteArrayInputStream(Array.empty[Byte]).asRight[Unit])
+      blockingResult(i.readAllBytes()).map(_ => new ByteArrayInputStream(Array.empty[Byte]).asRight[Unit])
     }) { (backend, baseUri) =>
       val tooLargeBody: String = List.fill(maxLength)('x').mkString
       basicRequest.post(uri"$baseUri/api/echo").body(tooLargeBody).response(asByteArray).send(backend).map { r =>

diff --git a/server/tests/src/main/scala/sttp/tapir/server/tests/package.scala b/server/tests/src/main/scala/sttp/tapir/server/tests/package.scala
index 8fd69d9836..c98dc61838 100644
--- a/server/tests/src/main/scala/sttp/tapir/server/tests/package.scala
+++ b/server/tests/src/main/scala/sttp/tapir/server/tests/package.scala
@@ -10,6 +10,7 @@ import sttp.monad.MonadError
 package object tests {
   val backendResource: Resource[IO, SttpBackend[IO, Fs2Streams[IO] with WebSockets]] = HttpClientFs2Backend.resource()
   val basicStringRequest: PartialRequest[String, Any] = basicRequest.response(asStringAlways)
-  def pureResult[F[_]: MonadError, T](t: T): F[T] = implicitly[MonadError[F]].unit(t)
-  def suspendResult[F[_]: MonadError, T](t: => T): F[T] = implicitly[MonadError[F]].eval(t)
+  def pureResult[F[_]: MonadError, T](t: T): F[T] = MonadError[F].unit(t)
+  def suspendResult[F[_]: MonadError, T](t: => T): F[T] = MonadError[F].eval(t)
+  def blockingResult[F[_]: MonadError, T](t: => T): F[T] = MonadError[F].blocking(t)
 }

From a07598ce44ce726957bae7e83ed7da00c5df9521 Mon Sep 17 00:00:00 2001
From: Roman Janusz
Date: Tue, 26 Mar 2024 11:52:25 +0100
Subject: [PATCH 11/16] added proper `blocking` impl to all MonadError
 instances

---
 .../src/main/scala/sttp/tapir/integ/cats/MonadErrorSyntax.scala | 2 ++
 .../test/scala/sttp/tapir/ztapir/instances/TestMonadError.scala | 2 ++
 .../tapir/server/vertx/cats/VertxCatsServerInterpreter.scala    | 1 +
 3 files changed, 5 insertions(+)

diff --git a/integrations/cats/src/main/scala/sttp/tapir/integ/cats/MonadErrorSyntax.scala b/integrations/cats/src/main/scala/sttp/tapir/integ/cats/MonadErrorSyntax.scala
index 9f24cc8d73..e09d1ce90f 100644
--- a/integrations/cats/src/main/scala/sttp/tapir/integ/cats/MonadErrorSyntax.scala
+++ b/integrations/cats/src/main/scala/sttp/tapir/integ/cats/MonadErrorSyntax.scala
@@ -21,6 +21,8 @@ trait MonadErrorSyntax {
       })
 
       override def ensure[T](f: G[T], e: => G[Unit]): G[T] = fk(mef.ensure(gK(f), gK(e)))
+
+      override def blocking[T](t: => T): G[T] = fk(mef.blocking(t))
     }
   }
 }

diff --git a/integrations/zio/src/test/scala/sttp/tapir/ztapir/instances/TestMonadError.scala b/integrations/zio/src/test/scala/sttp/tapir/ztapir/instances/TestMonadError.scala
index 193da4bb10..7c32c40967 100644
--- a/integrations/zio/src/test/scala/sttp/tapir/ztapir/instances/TestMonadError.scala
+++ b/integrations/zio/src/test/scala/sttp/tapir/ztapir/instances/TestMonadError.scala
@@ -19,5 +19,7 @@ object TestMonadError {
         rt.catchSome(h)
 
     override def ensure[T](f: TestEffect[T], e: => TestEffect[Unit]): TestEffect[T] = f.ensuring(e.ignore)
+
+    override def blocking[T](t: => T): TestEffect[T] = ZIO.attemptBlocking(t)
   }
 }

diff --git a/server/vertx-server/cats/src/main/scala/sttp/tapir/server/vertx/cats/VertxCatsServerInterpreter.scala b/server/vertx-server/cats/src/main/scala/sttp/tapir/server/vertx/cats/VertxCatsServerInterpreter.scala
index e9f0d8d6d2..9bc6dedfc5 100644
--- a/server/vertx-server/cats/src/main/scala/sttp/tapir/server/vertx/cats/VertxCatsServerInterpreter.scala
+++ b/server/vertx-server/cats/src/main/scala/sttp/tapir/server/vertx/cats/VertxCatsServerInterpreter.scala
@@ -117,6 +117,7 @@ object VertxCatsServerInterpreter {
     override def suspend[T](t: => F[T]): F[T] = F.defer(t)
     override def flatten[T](ffa: F[F[T]]): F[T] = F.flatten(ffa)
     override def ensure[T](f: F[T], e: => F[Unit]): F[T] = F.guaranteeCase(f)(_ => e)
+    override def blocking[T](t: => T): F[T] = F.blocking(t)
   }
 
   private[cats] class CatsFFromVFuture[F[_]: Async] extends FromVFuture[F] {
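The tests can now distinguish three ways of lifting a value: pureResult (already computed), suspendResult (lazy), and blockingResult (lazy and thread-blocking, so the monad can shift it off the compute pool). For Future-based monads a blocking combinator typically amounts to the following sketch (illustrative; the concrete MonadError implementations for each backend are in the diffs above):

    import scala.concurrent.{blocking, ExecutionContext, Future}

    // scala.concurrent.blocking marks the region so that ForkJoinPool-backed
    // execution contexts can spawn compensating threads instead of starving.
    def blockingFuture[T](t: => T)(implicit ec: ExecutionContext): Future[T] =
      Future(blocking(t))

For ZIO the patch uses ZIO.attemptBlocking, and for cats-effect F.blocking; both move the computation to a dedicated blocking thread pool.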
From 221013c549e7399746b40e849b5d239a18926139 Mon Sep 17 00:00:00 2001
From: Roman Janusz
Date: Tue, 26 Mar 2024 12:29:21 +0100
Subject: [PATCH 12/16] more blocking code management in tests

---
 .../main/scala/sttp/tapir/server/tests/ServerMetricsTest.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerMetricsTest.scala b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerMetricsTest.scala
index eef3338891..6eb073ef2f 100644
--- a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerMetricsTest.scala
+++ b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerMetricsTest.scala
@@ -65,7 +65,7 @@ class ServerMetricsTest[F[_], OPTIONS, ROUTE](createServerTest: CreateServerTest
       testServer(
         in_input_stream_out_input_stream.name("metrics"),
         interceptors = (ci: CustomiseInterceptors[F, OPTIONS]) => ci.metricsInterceptor(metrics)
-      )(is => (new ByteArrayInputStream(inputStreamToByteArray(is)): InputStream).asRight[Unit].unit) { (backend, baseUri) =>
+      )(is => blockingResult((new ByteArrayInputStream(inputStreamToByteArray(is)): InputStream).asRight[Unit])) { (backend, baseUri) =>
         basicRequest
           .post(uri"$baseUri/api/echo")
           .body("okoń")

From a4fec4b9bf22e0e3f7f3b70bd6d0281dd53e2c39 Mon Sep 17 00:00:00 2001
From: Roman Janusz
Date: Tue, 26 Mar 2024 13:03:18 +0100
Subject: [PATCH 13/16] cosmetic

---
 .../reactivestreams/SubscriberInputStreamTest.scala | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/server/netty-server/src/test/scala/sttp/tapir/server/netty/internal/reactivestreams/SubscriberInputStreamTest.scala b/server/netty-server/src/test/scala/sttp/tapir/server/netty/internal/reactivestreams/SubscriberInputStreamTest.scala
index f1f4364681..8fce1a0ff7 100644
--- a/server/netty-server/src/test/scala/sttp/tapir/server/netty/internal/reactivestreams/SubscriberInputStreamTest.scala
+++ b/server/netty-server/src/test/scala/sttp/tapir/server/netty/internal/reactivestreams/SubscriberInputStreamTest.scala
@@ -31,16 +31,15 @@ class SubscriberInputStreamTest extends AnyFreeSpec with Matchers {
       .covary[IO]
       .toUnicastPublisher
 
-    val readBytes = publisherResource.use { publisher =>
+    val io = publisherResource.use { publisher =>
       IO {
         val subscriberInputStream = new SubscriberInputStream(maxBufferedChunks)
         publisher.subscribe(subscriberInputStream)
-        subscriberInputStream.readAllBytes()
+        subscriberInputStream.readAllBytes() shouldBe bytes
       }
     }
 
-    readBytes.unsafeRunSync() shouldBe bytes
-    ()
+    io.unsafeRunSync()
   }
 
   "empty stream" in {

From f0a91603dc58374f0e0cd36ecc1f6b41d653b550 Mon Sep 17 00:00:00 2001
From: Roman Janusz
Date: Wed, 27 Mar 2024 08:47:13 +0100
Subject: [PATCH 14/16] replaced `synchronized` usages with a lock

---
 .../SubscriberInputStream.scala | 25 +++++++++++++------
 1 file changed, 18 insertions(+), 7 deletions(-)

diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/SubscriberInputStream.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/SubscriberInputStream.scala
index 62a99c9539..66c7f21ea5 100644
--- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/SubscriberInputStream.scala
+++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/SubscriberInputStream.scala
@@ -7,6 +7,7 @@ import sttp.capabilities.StreamMaxLengthExceededException
 
 import java.io.{IOException, InputStream}
 import java.util.concurrent.LinkedBlockingQueue
+import java.util.concurrent.locks.ReentrantLock
 import scala.annotation.tailrec
 import scala.concurrent.Promise
 
@@ -14,19 +15,29 @@ import scala.concurrent.Promise
  * @param maxBufferedChunks
  *   maximum number of unread chunks that can be buffered before blocking the publisher
  */
-private[netty] class SubscriberInputStream(maxBufferedChunks: Int = 1)
-    extends InputStream with Subscriber[HttpContent] {
-
+private[netty] class SubscriberInputStream(maxBufferedChunks: Int = 1) extends InputStream with Subscriber[HttpContent] {
+
   require(maxBufferedChunks > 0)
 
   import SubscriberInputStream._
 
   // volatile because used in both InputStream & Subscriber methods
   @volatile private[this] var closed = false
+
   // Calls on the subscription must be synchronized in order to satisfy the Reactive Streams spec
   // (https://github.com/reactive-streams/reactive-streams-jvm?tab=readme-ov-file#2-subscriber-code - rule 7)
   // because they are called both from InputStream & Subscriber methods.
   private[this] var subscription: Subscription = _
+  private[this] val lock = new ReentrantLock
+
+  private def locked[T](code: => T): T =
+    try {
+      lock.lock()
+      code
+    } finally {
+      lock.unlock()
+    }
+
   private[this] var currentItem: Item = _
   // the queue serves as a buffer to allow for possible parallelism between the subscriber and the publisher
   private val queue = new LinkedBlockingQueue[Item](maxBufferedChunks + 1) // +1 to have a spot for End/Error
@@ -35,7 +46,7 @@ private[netty] class SubscriberInputStream(maxBufferedChunks: Int = 1)
     if (currentItem eq null) {
       currentItem = if (blocking) queue.take() else queue.poll()
       currentItem match {
-        case _: Chunk => synchronized(subscription.request(1))
+        case _: Chunk => locked(subscription.request(1))
         case _        =>
       }
     }
@@ -71,7 +82,7 @@
   }
 
   override def close(): Unit = if (!closed) {
-    synchronized(subscription.cancel())
+    locked(subscription.cancel())
     closed = true
     clearQueue()
   }
@@ -84,7 +95,7 @@
       case _ =>
     }
 
-  override def onSubscribe(s: Subscription): Unit = synchronized {
+  override def onSubscribe(s: Subscription): Unit = locked {
     if (s eq null) {
       throw new NullPointerException("Subscription must not be null")
     }
     subscription = s
     subscription.request(maxBufferedChunks)
   }
@@ -97,7 +108,7 @@
   override def onNext(chunk: HttpContent): Unit = {
     if (!queue.offer(Chunk(chunk.content()))) {
       // This should be impossible according to the Reactive Streams spec,
       // if it happens then it's a bug in the implementation of the subscriber or publisher
       chunk.release()
-      synchronized(subscription.cancel())
+      locked(subscription.cancel())
     } else if (closed) {
       clearQueue()
     }
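Replacing `synchronized` with an explicit ReentrantLock keeps the same mutual-exclusion semantics but swaps monitor-based locking for a java.util.concurrent lock. One plausible motivation (the commit message does not state one) is that this class is used on the blocking path of the netty-loom backend, and on JDK 21 a virtual thread that blocks inside a `synchronized` block pins its carrier thread, whereas parking on a ReentrantLock does not. The helper pattern in isolation (a sketch with illustrative names; the canonical idiom acquires the lock just before `try`, while the patch acquires it inside):

    import java.util.concurrent.locks.ReentrantLock

    final class Counter {
      private val lock = new ReentrantLock
      private var value = 0

      // Guard arbitrary code with the lock; unlocking in `finally`
      // guarantees release even if `code` throws.
      private def locked[T](code: => T): T = {
        lock.lock()
        try code
        finally lock.unlock()
      }

      def increment(): Int = locked { value += 1; value }
    }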
From f59d611400fddc035242d77b2158eea52b91b002 Mon Sep 17 00:00:00 2001
From: Roman Janusz
Date: Wed, 27 Mar 2024 09:13:44 +0100
Subject: [PATCH 15/16] SubscriberInputStreamTest improvements

---
 .../SubscriberInputStreamTest.scala | 48 +++++++++++++------
 1 file changed, 33 insertions(+), 15 deletions(-)

diff --git a/server/netty-server/src/test/scala/sttp/tapir/server/netty/internal/reactivestreams/SubscriberInputStreamTest.scala b/server/netty-server/src/test/scala/sttp/tapir/server/netty/internal/reactivestreams/SubscriberInputStreamTest.scala
index 8fce1a0ff7..602287f4d0 100644
--- a/server/netty-server/src/test/scala/sttp/tapir/server/netty/internal/reactivestreams/SubscriberInputStreamTest.scala
+++ b/server/netty-server/src/test/scala/sttp/tapir/server/netty/internal/reactivestreams/SubscriberInputStreamTest.scala
@@ -7,26 +7,38 @@ import fs2.Stream
 import fs2.interop.reactivestreams._
 import io.netty.buffer.Unpooled
 import io.netty.handler.codec.http.DefaultHttpContent
+import org.scalactic.source.Position
 import org.scalatest.freespec.AnyFreeSpec
 import org.scalatest.matchers.should.Matchers
-import org.scalactic.source.Position
 
+import java.io.InputStream
+import scala.annotation.tailrec
 import scala.util.Random
 
 class SubscriberInputStreamTest extends AnyFreeSpec with Matchers {
   private implicit def runtime: IORuntime = IORuntime.global
 
+  private def readAll(is: InputStream, batchSize: Int): Array[Byte] = {
+    val buf = Unpooled.buffer(batchSize)
+    @tailrec def writeLoop(): Array[Byte] = buf.writeBytes(is, batchSize) match {
+      case -1 => buf.array().take(buf.readableBytes())
+      case _  => writeLoop()
+    }
+    writeLoop()
+  }
+
   private def testReading(
-      size: Int,
-      chunkLimit: Int = 1024,
+      totalSize: Int,
+      publishedChunkLimit: Int,
+      readBatchSize: Int,
       maxBufferedChunks: Int = 1
   )(implicit pos: Position): Unit = {
-    val bytes = new Array[Byte](size)
+    val bytes = new Array[Byte](totalSize)
     Random.nextBytes(bytes)
 
     val publisherResource = Stream
       .emits(bytes)
-      .chunkLimit(chunkLimit)
+      .chunkLimit(publishedChunkLimit)
      .map(ch => new DefaultHttpContent(Unpooled.wrappedBuffer(ch.toByteBuffer)))
       .covary[IO]
       .toUnicastPublisher
@@ -35,7 +47,7 @@ class SubscriberInputStreamTest extends AnyFreeSpec with Matchers {
       IO {
         val subscriberInputStream = new SubscriberInputStream(maxBufferedChunks)
         publisher.subscribe(subscriberInputStream)
-        subscriberInputStream.readAllBytes() shouldBe bytes
+        readAll(subscriberInputStream, readBatchSize) shouldBe bytes
       }
     }
 
@@ -43,23 +55,29 @@ class SubscriberInputStreamTest extends AnyFreeSpec with Matchers {
   }
 
   "empty stream" in {
-    testReading(0)
+    testReading(totalSize = 0, publishedChunkLimit = 1024, readBatchSize = 1024)
+  }
+
+  "single chunk stream, one read batch" in {
+    testReading(totalSize = 10, publishedChunkLimit = 1024, readBatchSize = 1024)
   }
 
-  "single chunk stream" in {
-    testReading(10)
+  "single chunk stream, multiple read batches" in {
+    testReading(totalSize = 100, publishedChunkLimit = 1024, readBatchSize = 10)
+    testReading(totalSize = 100, publishedChunkLimit = 1024, readBatchSize = 11)
   }
 
-  "multiple chunks" in {
-    testReading(100, 10)
+  "multiple chunks, read batch larger than chunk" in {
+    testReading(totalSize = 100, publishedChunkLimit = 10, readBatchSize = 1024)
+    testReading(totalSize = 105, publishedChunkLimit = 10, readBatchSize = 1024)
   }
 
-  "multiple chunks with larger buffer" in {
-    testReading(100, 10, maxBufferedChunks = 5)
+  "multiple chunks, read batch smaller than chunk" in {
+    testReading(totalSize = 105, publishedChunkLimit = 20, readBatchSize = 17)
   }
 
-  "multiple chunks with smaller last chunk" in {
-    testReading(105, 10)
+  "multiple chunks, large publishing buffer" in {
+    testReading(totalSize = 105, publishedChunkLimit = 10, readBatchSize = 1024, maxBufferedChunks = 5)
   }
 
   "closing the stream should cancel the subscription" in {

From 958942449e40f64871a012ddef00d68216b1eb06 Mon Sep 17 00:00:00 2001
From: Roman Janusz
Date: Wed, 27 Mar 2024 09:15:51 +0100
Subject: [PATCH 16/16] SubscriberInputStreamTest improvements

---
 .../internal/reactivestreams/SubscriberInputStreamTest.scala | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/server/netty-server/src/test/scala/sttp/tapir/server/netty/internal/reactivestreams/SubscriberInputStreamTest.scala b/server/netty-server/src/test/scala/sttp/tapir/server/netty/internal/reactivestreams/SubscriberInputStreamTest.scala
index 602287f4d0..26f25bdd56 100644
--- a/server/netty-server/src/test/scala/sttp/tapir/server/netty/internal/reactivestreams/SubscriberInputStreamTest.scala
+++ b/server/netty-server/src/test/scala/sttp/tapir/server/netty/internal/reactivestreams/SubscriberInputStreamTest.scala
@@ -74,6 +74,8 @@ class SubscriberInputStreamTest extends AnyFreeSpec with Matchers {
 
   "multiple chunks, read batch smaller than chunk" in {
     testReading(totalSize = 105, publishedChunkLimit = 20, readBatchSize = 17)
+    testReading(totalSize = 105, publishedChunkLimit = 20, readBatchSize = 7)
+    testReading(totalSize = 105, publishedChunkLimit = 20, readBatchSize = 5)
   }
 
   "multiple chunks, large publishing buffer" in {