diff --git a/integrations/cats/src/main/scala/sttp/tapir/integ/cats/MonadErrorSyntax.scala b/integrations/cats/src/main/scala/sttp/tapir/integ/cats/MonadErrorSyntax.scala
index 9f24cc8d73..e09d1ce90f 100644
--- a/integrations/cats/src/main/scala/sttp/tapir/integ/cats/MonadErrorSyntax.scala
+++ b/integrations/cats/src/main/scala/sttp/tapir/integ/cats/MonadErrorSyntax.scala
@@ -21,6 +21,8 @@ trait MonadErrorSyntax {
         })
 
       override def ensure[T](f: G[T], e: => G[Unit]): G[T] = fk(mef.ensure(gK(f), gK(e)))
+
+      override def blocking[T](t: => T): G[T] = fk(mef.blocking(t))
     }
   }
 }
diff --git a/integrations/zio/src/main/scala/sttp/tapir/ztapir/RIOMonadError.scala b/integrations/zio/src/main/scala/sttp/tapir/ztapir/RIOMonadError.scala
index 00c4432f95..d517b5ab7b 100644
--- a/integrations/zio/src/main/scala/sttp/tapir/ztapir/RIOMonadError.scala
+++ b/integrations/zio/src/main/scala/sttp/tapir/ztapir/RIOMonadError.scala
@@ -13,4 +13,5 @@ class RIOMonadError[R] extends MonadError[RIO[R, *]] {
   override def suspend[T](t: => RIO[R, T]): RIO[R, T] = ZIO.suspend(t)
   override def flatten[T](ffa: RIO[R, RIO[R, T]]): RIO[R, T] = ffa.flatten
   override def ensure[T](f: RIO[R, T], e: => RIO[R, Unit]): RIO[R, T] = f.ensuring(e.ignore)
+  override def blocking[T](t: => T): RIO[R, T] = ZIO.attemptBlocking(t)
 }
diff --git a/integrations/zio/src/test/scala/sttp/tapir/ztapir/instances/TestMonadError.scala b/integrations/zio/src/test/scala/sttp/tapir/ztapir/instances/TestMonadError.scala
index 193da4bb10..7c32c40967 100644
--- a/integrations/zio/src/test/scala/sttp/tapir/ztapir/instances/TestMonadError.scala
+++ b/integrations/zio/src/test/scala/sttp/tapir/ztapir/instances/TestMonadError.scala
@@ -19,5 +19,7 @@ object TestMonadError {
       rt.catchSome(h)
 
     override def ensure[T](f: TestEffect[T], e: => TestEffect[Unit]): TestEffect[T] = f.ensuring(e.ignore)
+
+    override def blocking[T](t: => T): TestEffect[T] = ZIO.attemptBlocking(t)
   }
 }
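The instances above delegate `blocking` to the underlying effect system: the mapped cats instance forwards through the natural transformation (`fk(mef.blocking(t))`), while the ZIO instances use `ZIO.attemptBlocking`, which runs the thunk on ZIO's blocking thread pool. As a rough sketch of the same contract for a `Future`-based monad — assuming, as this diff implies, that `MonadError` now declares an overridable `blocking`, that `FutureMonad` is non-final, and with `blockingEc` as a hypothetical caller-provided executor:

```scala
import scala.concurrent.{ExecutionContext, Future}
import sttp.monad.FutureMonad

// Sketch: run blocking thunks on a dedicated executor instead of the main pool.
// `blockingEc` is an assumed, caller-provided blocking ExecutionContext.
class BlockingFutureMonad(blockingEc: ExecutionContext)(implicit ec: ExecutionContext) extends FutureMonad {
  override def blocking[T](t: => T): Future[T] = Future(t)(blockingEc)
}
```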
diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyRequestBody.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyRequestBody.scala
index 78f01c0ea0..f7cd386dbb 100644
--- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyRequestBody.scala
+++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyRequestBody.scala
@@ -4,19 +4,17 @@ import io.netty.buffer.Unpooled
 import io.netty.handler.codec.http.{FullHttpRequest, HttpContent}
 import org.playframework.netty.http.StreamedHttpRequest
 import org.reactivestreams.Publisher
+import sttp.capabilities.Streams
+import sttp.model.HeaderNames
 import sttp.monad.MonadError
 import sttp.monad.syntax._
+import sttp.tapir.{FileRange, InputStreamRange, RawBodyType, TapirFile}
 import sttp.tapir.model.ServerRequest
-import sttp.tapir.server.interpreter.RequestBody
-import sttp.tapir.RawBodyType
-import sttp.tapir.TapirFile
-import sttp.tapir.server.interpreter.RawValue
-import sttp.tapir.FileRange
-import sttp.tapir.InputStreamRange
-import java.io.ByteArrayInputStream
+import sttp.tapir.server.interpreter.{RawValue, RequestBody}
+import sttp.tapir.server.netty.internal.reactivestreams.SubscriberInputStream
+
+import java.io.{ByteArrayInputStream, InputStream}
 import java.nio.ByteBuffer
-import sttp.capabilities.Streams
-import sttp.model.HeaderNames
 /** Common logic for processing request body in all Netty backends. It requires particular backends to implement a few operations.
   */
 private[netty] trait NettyRequestBody[F[_], S <: Streams[S]] extends RequestBody[F, S] {
@@ -64,17 +62,16 @@
       case RawBodyType.ByteBufferBody =>
         readAllBytes(serverRequest, maxBytes).map(bs => RawValue(ByteBuffer.wrap(bs)))
       case RawBodyType.InputStreamBody =>
-        // Possibly can be optimized to avoid loading all data eagerly into memory
-        readAllBytes(serverRequest, maxBytes).map(bs => RawValue(new ByteArrayInputStream(bs)))
+        monad.eval(RawValue(readAsStream(serverRequest, maxBytes)))
       case RawBodyType.InputStreamRangeBody =>
-        // Possibly can be optimized to avoid loading all data eagerly into memory
-        readAllBytes(serverRequest, maxBytes).map(bs => RawValue(InputStreamRange(() => new ByteArrayInputStream(bs))))
+        monad.unit(RawValue(InputStreamRange(() => readAsStream(serverRequest, maxBytes))))
       case RawBodyType.FileBody =>
         for {
          file <- createFile(serverRequest)
          _ <- writeToFile(serverRequest, file, maxBytes)
        } yield RawValue(FileRange(file), Seq(FileRange(file)))
-      case _: RawBodyType.MultipartBody => monad.error(new UnsupportedOperationException)
+      case _: RawBodyType.MultipartBody =>
+        monad.error(new UnsupportedOperationException)
     }
 
   private def readAllBytes(serverRequest: ServerRequest, maxBytes: Option[Long]): F[Array[Byte]] =
@@ -84,7 +81,19 @@
       case req: StreamedHttpRequest =>
         val contentLength = Option(req.headers().get(HeaderNames.ContentLength)).map(_.toLong)
         publisherToBytes(req, contentLength, maxBytes)
-      case other =>
+      case other =>
         monad.error(new UnsupportedOperationException(s"Unexpected Netty request of type ${other.getClass.getName}"))
     }
+
+  private def readAsStream(serverRequest: ServerRequest, maxBytes: Option[Long]): InputStream = {
+    serverRequest.underlying match {
+      case r: FullHttpRequest if r.content() == Unpooled.EMPTY_BUFFER => // Empty request
+        InputStream.nullInputStream()
+      case req: StreamedHttpRequest =>
+        val contentLength = Option(req.headers().get(HeaderNames.ContentLength)).map(_.toLong)
+        SubscriberInputStream.processAsStream(req, contentLength, maxBytes)
+      case other =>
+        throw new UnsupportedOperationException(s"Unexpected Netty request of type ${other.getClass.getName}")
+    }
+  }
 }
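The behavioral change here: `InputStreamBody` and `InputStreamRangeBody` no longer buffer the whole request via `readAllBytes`. Instead, server logic receives an `InputStream` backed directly by the live Netty stream (via `readAsStream`, which delegates to `SubscriberInputStream.processAsStream`, introduced in the next file), so every `read` may block until the client sends more data. Note the asymmetry: the plain stream is wrapped in `monad.eval`, deferring creation until the effect runs, while the range variant can use `monad.unit` because `InputStreamRange` already defers creation inside its thunk. A hedged sketch of how downstream server logic should consume such a stream (`echoLogic` is illustrative, not part of the diff):

```scala
import java.io.InputStream
import sttp.monad.MonadError

// Sketch: drain a lazily-fed request body inside `blocking`, since reads may
// park the thread while Netty delivers further HTTP chunks.
def echoLogic[F[_]](is: InputStream)(implicit m: MonadError[F]): F[Array[Byte]] =
  m.blocking(is.readAllBytes())
```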
diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/SubscriberInputStream.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/SubscriberInputStream.scala
new file mode 100644
index 0000000000..66c7f21ea5
--- /dev/null
+++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/SubscriberInputStream.scala
@@ -0,0 +1,147 @@
+package sttp.tapir.server.netty.internal.reactivestreams
+
+import io.netty.buffer.ByteBuf
+import io.netty.handler.codec.http.HttpContent
+import org.reactivestreams.{Publisher, Subscriber, Subscription}
+import sttp.capabilities.StreamMaxLengthExceededException
+
+import java.io.{IOException, InputStream}
+import java.util.concurrent.LinkedBlockingQueue
+import java.util.concurrent.locks.ReentrantLock
+import scala.annotation.tailrec
+
+/** A blocking input stream that reads from a reactive streams publisher of [[HttpContent]].
+  * @param maxBufferedChunks
+  *   maximum number of unread chunks that can be buffered before blocking the publisher
+  */
+private[netty] class SubscriberInputStream(maxBufferedChunks: Int = 1) extends InputStream with Subscriber[HttpContent] {
+
+  require(maxBufferedChunks > 0)
+
+  import SubscriberInputStream._
+
+  // volatile because used in both InputStream & Subscriber methods
+  @volatile private[this] var closed = false
+
+  // Calls on the subscription must be synchronized in order to satisfy the Reactive Streams spec
+  // (https://github.com/reactive-streams/reactive-streams-jvm?tab=readme-ov-file#2-subscriber-code - rule 7)
+  // because they are called both from InputStream & Subscriber methods.
+  private[this] var subscription: Subscription = _
+  private[this] val lock = new ReentrantLock
+
+  private def locked[T](code: => T): T =
+    try {
+      lock.lock()
+      code
+    } finally {
+      lock.unlock()
+    }
+
+  private[this] var currentItem: Item = _
+  // the queue serves as a buffer to allow for possible parallelism between the subscriber and the publisher
+  private val queue = new LinkedBlockingQueue[Item](maxBufferedChunks + 1) // +1 to have a spot for End/Error
+
+  private def readItem(blocking: Boolean): Item = {
+    if (currentItem eq null) {
+      currentItem = if (blocking) queue.take() else queue.poll()
+      currentItem match {
+        case _: Chunk => locked(subscription.request(1))
+        case _        =>
+      }
+    }
+    currentItem
+  }
+
+  override def available(): Int =
+    readItem(blocking = false) match {
+      case Chunk(data) => data.readableBytes()
+      case _           => 0
+    }
+
+  override def read(): Int = {
+    val buffer = new Array[Byte](1)
+    // mask with 0xff: Byte is signed, but read() must return a value in 0-255 (or -1 for EOF)
+    if (read(buffer) == -1) -1 else buffer(0) & 0xff
+  }
+
+  override def read(b: Array[Byte], off: Int, len: Int): Int =
+    if (closed) throw new IOException("Stream closed")
+    else if (len == 0) 0
+    else
+      readItem(blocking = true) match {
+        case Chunk(data) =>
+          val toRead = Math.min(len, data.readableBytes())
+          data.readBytes(b, off, toRead)
+          if (data.readableBytes() == 0) {
+            data.release()
+            currentItem = null
+          }
+          toRead
+        case Error(cause) => throw cause
+        case End          => -1
+      }
+
+  override def close(): Unit = if (!closed) {
+    locked(subscription.cancel())
+    closed = true
+    clearQueue()
+  }
+
+  @tailrec private def clearQueue(): Unit =
+    queue.poll() match {
+      case Chunk(data) =>
+        data.release()
+        clearQueue()
+      case _ =>
+    }
+
+  override def onSubscribe(s: Subscription): Unit = locked {
+    if (s eq null) {
+      throw new NullPointerException("Subscription must not be null")
+    }
+    subscription = s
+    subscription.request(maxBufferedChunks)
+  }
+
+  override def onNext(chunk: HttpContent): Unit = {
+    if (!queue.offer(Chunk(chunk.content()))) {
+      // This should be impossible according to the Reactive Streams spec,
+      // if it happens then it's a bug in the implementation of the subscriber or the publisher
+      chunk.release()
+      locked(subscription.cancel())
+    } else if (closed) {
+      clearQueue()
+    }
+  }
+
+  override def onError(t: Throwable): Unit =
+    if (!closed) {
+      queue.offer(Error(t))
+    }
+
+  override def onComplete(): Unit =
+    if (!closed) {
+      queue.offer(End)
+    }
+}
+
+private[netty] object SubscriberInputStream {
+  private sealed abstract class Item
+  private case class Chunk(data: ByteBuf) extends Item
+  private case class Error(cause: Throwable) extends Item
+  private object End extends Item
+
+  def processAsStream(
+      publisher: Publisher[HttpContent],
+      contentLength: Option[Long],
+      maxBytes: Option[Long],
+      maxBufferedChunks: Int = 1
+  ): InputStream = maxBytes match {
+    case Some(max) if contentLength.exists(_ > max) =>
+      throw StreamMaxLengthExceededException(max)
+    case _ =>
+      val subscriber = new SubscriberInputStream(maxBufferedChunks)
+      val maybeLimitedSubscriber = maxBytes.map(new LimitedLengthSubscriber(_, subscriber)).getOrElse(subscriber)
+      publisher.subscribe(maybeLimitedSubscriber)
+      subscriber
+  }
+}
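`SubscriberInputStream` bridges the reactive world to blocking reads: chunks land in a bounded `LinkedBlockingQueue` (at most `maxBufferedChunks` outstanding, since one new chunk is requested only after the previous one is consumed), with one extra slot reserved for the terminal `End`/`Error` item. `processAsStream` additionally rejects bodies up front when the declared `Content-Length` already exceeds `maxBytes`, and otherwise wraps the subscriber in `LimitedLengthSubscriber` to enforce the limit as data flows. A usage sketch, exercised more thoroughly by the test file below (assuming code living in the same `reactivestreams` package, since both the class and `processAsStream` are `private[netty]`; the 1 MiB limit is an arbitrary example):

```scala
import io.netty.handler.codec.http.HttpContent
import org.reactivestreams.Publisher

// Sketch: turn a Publisher[HttpContent] into a blocking InputStream and drain it.
def drainToBytes(publisher: Publisher[HttpContent], declaredLength: Option[Long]): Array[Byte] = {
  val is = SubscriberInputStream.processAsStream(
    publisher,
    contentLength = declaredLength,
    maxBytes = Some(1048576L) // fails with StreamMaxLengthExceededException beyond 1 MiB
  )
  try is.readAllBytes()
  finally is.close() // cancels the subscription if the publisher is still active
}
```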
diff --git a/server/netty-server/src/test/scala/sttp/tapir/server/netty/internal/reactivestreams/SubscriberInputStreamTest.scala b/server/netty-server/src/test/scala/sttp/tapir/server/netty/internal/reactivestreams/SubscriberInputStreamTest.scala
new file mode 100644
index 0000000000..26f25bdd56
--- /dev/null
+++ b/server/netty-server/src/test/scala/sttp/tapir/server/netty/internal/reactivestreams/SubscriberInputStreamTest.scala
@@ -0,0 +1,114 @@
+package sttp.tapir.server.netty.internal.reactivestreams
+
+import cats.effect.IO
+import cats.effect.kernel.Resource
+import cats.effect.unsafe.IORuntime
+import fs2.Stream
+import fs2.interop.reactivestreams._
+import io.netty.buffer.Unpooled
+import io.netty.handler.codec.http.DefaultHttpContent
+import org.scalactic.source.Position
+import org.scalatest.freespec.AnyFreeSpec
+import org.scalatest.matchers.should.Matchers
+
+import java.io.InputStream
+import scala.annotation.tailrec
+import scala.util.Random
+
+class SubscriberInputStreamTest extends AnyFreeSpec with Matchers {
+  private implicit def runtime: IORuntime = IORuntime.global
+
+  private def readAll(is: InputStream, batchSize: Int): Array[Byte] = {
+    val buf = Unpooled.buffer(batchSize)
+    @tailrec def writeLoop(): Array[Byte] = buf.writeBytes(is, batchSize) match {
+      case -1 => buf.array().take(buf.readableBytes())
+      case _  => writeLoop()
+    }
+    writeLoop()
+  }
+
+  private def testReading(
+      totalSize: Int,
+      publishedChunkLimit: Int,
+      readBatchSize: Int,
+      maxBufferedChunks: Int = 1
+  )(implicit pos: Position): Unit = {
+    val bytes = new Array[Byte](totalSize)
+    Random.nextBytes(bytes)
+
+    val publisherResource = Stream
+      .emits(bytes)
+      .chunkLimit(publishedChunkLimit)
+      .map(ch => new DefaultHttpContent(Unpooled.wrappedBuffer(ch.toByteBuffer)))
+      .covary[IO]
+      .toUnicastPublisher
+
+    val io = publisherResource.use { publisher =>
+      IO {
+        val subscriberInputStream = new SubscriberInputStream(maxBufferedChunks)
+        publisher.subscribe(subscriberInputStream)
+        readAll(subscriberInputStream, readBatchSize) shouldBe bytes
+      }
+    }
+
+    io.unsafeRunSync()
+  }
+
+  "empty stream" in {
+    testReading(totalSize = 0, publishedChunkLimit = 1024, readBatchSize = 1024)
+  }
+
+  "single chunk stream, one read batch" in {
+    testReading(totalSize = 10, publishedChunkLimit = 1024, readBatchSize = 1024)
+  }
+
+  "single chunk stream, multiple read batches" in {
+    testReading(totalSize = 100, publishedChunkLimit = 1024, readBatchSize = 10)
+    testReading(totalSize = 100, publishedChunkLimit = 1024, readBatchSize = 11)
+  }
+
+  "multiple chunks, read batch larger than chunk" in {
+    testReading(totalSize = 100, publishedChunkLimit = 10, readBatchSize = 1024)
+    testReading(totalSize = 105, publishedChunkLimit = 10, readBatchSize = 1024)
+  }
+
+  "multiple chunks, read batch smaller than chunk" in {
+    testReading(totalSize = 105, publishedChunkLimit = 20, readBatchSize = 17)
+    testReading(totalSize = 105, publishedChunkLimit = 20, readBatchSize = 7)
+    testReading(totalSize = 105, publishedChunkLimit = 20, readBatchSize = 5)
+  }
+
+  "multiple chunks, large publishing buffer" in {
+    testReading(totalSize = 105, publishedChunkLimit = 10, readBatchSize = 1024, maxBufferedChunks = 5)
+  }
+
+  "closing the stream should cancel the subscription" in {
+    var canceled = false
+
+    val publisherResource =
+      Stream
+        .emits(Array.fill(1024)(0.toByte))
+        .chunkLimit(100)
+        .map(ch => new DefaultHttpContent(Unpooled.wrappedBuffer(ch.toByteBuffer)))
+        .covary[IO]
+        .onFinalizeCase {
+          case Resource.ExitCase.Canceled => IO { canceled = true }
+          case _                          => IO.unit
+        }
+        .toUnicastPublisher
+
+    publisherResource
+      .use(publisher =>
+        IO {
+          val stream = new SubscriberInputStream()
+          publisher.subscribe(stream)
+
+          stream.readNBytes(120).length shouldBe 120
+          stream.close()
+        }
+      )
+      .unsafeRunSync()
+
+    canceled shouldBe true
+  }
+}
"closing the stream should cancel the subscription" in { + var canceled = false + + val publisherResource = + Stream + .emits(Array.fill(1024)(0.toByte)) + .chunkLimit(100) + .map(ch => new DefaultHttpContent(Unpooled.wrappedBuffer(ch.toByteBuffer))) + .covary[IO] + .onFinalizeCase { + case Resource.ExitCase.Canceled => IO { canceled = true } + case _ => IO.unit + } + .toUnicastPublisher + + publisherResource + .use(publisher => + IO { + val stream = new SubscriberInputStream() + publisher.subscribe(stream) + + stream.readNBytes(120).length shouldBe 120 + stream.close() + } + ) + .unsafeRunSync() + + canceled shouldBe true + } +} diff --git a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerBasicTests.scala b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerBasicTests.scala index 000abfedb1..672481a278 100644 --- a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerBasicTests.scala +++ b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerBasicTests.scala @@ -735,7 +735,7 @@ class ServerBasicTests[F[_], OPTIONS, ROUTE]( def inputStreamTests(): List[Test] = List( testServer(in_input_stream_out_input_stream)((is: InputStream) => - pureResult((new ByteArrayInputStream(inputStreamToByteArray(is)): InputStream).asRight[Unit]) + blockingResult((new ByteArrayInputStream(inputStreamToByteArray(is)): InputStream).asRight[Unit]) ) { (backend, baseUri) => basicRequest.post(uri"$baseUri/api/echo").body("mango").send(backend).map(_.body shouldBe Right("mango")) }, testServer(in_string_out_stream_with_header)(_ => pureResult(Right((new ByteArrayInputStream(Array.fill[Byte](128)(0)), Some(128))))) { (backend, baseUri) => @@ -795,7 +795,7 @@ class ServerBasicTests[F[_], OPTIONS, ROUTE]( "checks payload limit and returns OK on content length below or equal max (request)" )(i => { // Forcing server logic to drain the InputStream - suspendResult(i.readAllBytes()).map(_ => new ByteArrayInputStream(Array.empty[Byte]).asRight[Unit]) + blockingResult(i.readAllBytes()).map(_ => new ByteArrayInputStream(Array.empty[Byte]).asRight[Unit]) }) { (backend, baseUri) => val tooLargeBody: String = List.fill(maxLength)('x').mkString basicRequest.post(uri"$baseUri/api/echo").body(tooLargeBody).response(asByteArray).send(backend).map { r => diff --git a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerMetricsTest.scala b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerMetricsTest.scala index eef3338891..6eb073ef2f 100644 --- a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerMetricsTest.scala +++ b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerMetricsTest.scala @@ -65,7 +65,7 @@ class ServerMetricsTest[F[_], OPTIONS, ROUTE](createServerTest: CreateServerTest testServer( in_input_stream_out_input_stream.name("metrics"), interceptors = (ci: CustomiseInterceptors[F, OPTIONS]) => ci.metricsInterceptor(metrics) - )(is => (new ByteArrayInputStream(inputStreamToByteArray(is)): InputStream).asRight[Unit].unit) { (backend, baseUri) => + )(is => blockingResult((new ByteArrayInputStream(inputStreamToByteArray(is)): InputStream).asRight[Unit])) { (backend, baseUri) => basicRequest .post(uri"$baseUri/api/echo") .body("okoĊ„") diff --git a/server/tests/src/main/scala/sttp/tapir/server/tests/package.scala b/server/tests/src/main/scala/sttp/tapir/server/tests/package.scala index 8fd69d9836..c98dc61838 100644 --- a/server/tests/src/main/scala/sttp/tapir/server/tests/package.scala +++ b/server/tests/src/main/scala/sttp/tapir/server/tests/package.scala @@ -10,6 +10,7 @@ 
diff --git a/server/tests/src/main/scala/sttp/tapir/server/tests/package.scala b/server/tests/src/main/scala/sttp/tapir/server/tests/package.scala
index 8fd69d9836..c98dc61838 100644
--- a/server/tests/src/main/scala/sttp/tapir/server/tests/package.scala
+++ b/server/tests/src/main/scala/sttp/tapir/server/tests/package.scala
@@ -10,6 +10,7 @@ import sttp.monad.MonadError
 package object tests {
   val backendResource: Resource[IO, SttpBackend[IO, Fs2Streams[IO] with WebSockets]] = HttpClientFs2Backend.resource()
   val basicStringRequest: PartialRequest[String, Any] = basicRequest.response(asStringAlways)
-  def pureResult[F[_]: MonadError, T](t: T): F[T] = implicitly[MonadError[F]].unit(t)
-  def suspendResult[F[_]: MonadError, T](t: => T): F[T] = implicitly[MonadError[F]].eval(t)
+  def pureResult[F[_]: MonadError, T](t: T): F[T] = MonadError[F].unit(t)
+  def suspendResult[F[_]: MonadError, T](t: => T): F[T] = MonadError[F].eval(t)
+  def blockingResult[F[_]: MonadError, T](t: => T): F[T] = MonadError[F].blocking(t)
 }
diff --git a/server/vertx-server/cats/src/main/scala/sttp/tapir/server/vertx/cats/VertxCatsServerInterpreter.scala b/server/vertx-server/cats/src/main/scala/sttp/tapir/server/vertx/cats/VertxCatsServerInterpreter.scala
index e9f0d8d6d2..9bc6dedfc5 100644
--- a/server/vertx-server/cats/src/main/scala/sttp/tapir/server/vertx/cats/VertxCatsServerInterpreter.scala
+++ b/server/vertx-server/cats/src/main/scala/sttp/tapir/server/vertx/cats/VertxCatsServerInterpreter.scala
@@ -117,6 +117,7 @@ object VertxCatsServerInterpreter {
     override def suspend[T](t: => F[T]): F[T] = F.defer(t)
     override def flatten[T](ffa: F[F[T]]): F[T] = F.flatten(ffa)
     override def ensure[T](f: F[T], e: => F[Unit]): F[T] = F.guaranteeCase(f)(_ => e)
+    override def blocking[T](t: => T): F[T] = F.blocking(t)
   }
 
   private[cats] class CatsFFromVFuture[F[_]: Async] extends FromVFuture[F] {
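Taken together, an end-to-end sketch of what this change enables: an `InputStream`-consuming endpoint served by the Netty backend now reads the request body on demand rather than after full buffering. This example is illustrative only — it assumes the `NettyFutureServer` backend and a simple echo endpoint, neither of which appears in this diff:

```scala
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.Future
import sttp.tapir._
import sttp.tapir.server.netty.NettyFutureServer

// Echo endpoint: the InputStream is fed incrementally by SubscriberInputStream,
// so large uploads are no longer materialized in memory up front.
val echo = endpoint.post
  .in("api" / "echo")
  .in(inputStreamBody)
  .out(stringBody)
  .serverLogicSuccess[Future](is => Future(new String(is.readAllBytes(), "UTF-8")))

// To run: NettyFutureServer().addEndpoints(List(echo)).start()
```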