From 7301c06f56191cf57d857085c3576da72140952a Mon Sep 17 00:00:00 2001 From: Krzysiek Ciesielski Date: Mon, 11 Mar 2024 11:01:09 +0100 Subject: [PATCH] Optimize SimpleSubscriber for Netty (#3583) --- .../cats/internal/NettyCatsRequestBody.scala | 2 +- .../netty/loom/NettyIdRequestBody.scala | 4 +- .../internal/NettyFutureRequestBody.scala | 4 +- .../netty/internal/NettyRequestBody.scala | 8 +- .../reactivestreams/SimpleSubscriber.scala | 118 ++++++++++++------ .../zio/internal/NettyZioRequestBody.scala | 6 +- 6 files changed, 99 insertions(+), 43 deletions(-) diff --git a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/NettyCatsRequestBody.scala b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/NettyCatsRequestBody.scala index e1a762ae70..e81f66fd00 100644 --- a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/NettyCatsRequestBody.scala +++ b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/NettyCatsRequestBody.scala @@ -20,7 +20,7 @@ private[cats] class NettyCatsRequestBody[F[_]: Async]( override implicit val monad: MonadError[F] = new CatsMonadError() - override def publisherToBytes(publisher: Publisher[HttpContent], maxBytes: Option[Long]): F[Array[Byte]] = + override def publisherToBytes(publisher: Publisher[HttpContent], contentLength: Option[Int], maxBytes: Option[Long]): F[Array[Byte]] = streamCompatible.fromPublisher(publisher, maxBytes).compile.to(Chunk).map(_.toArray[Byte]) override def writeToFile(serverRequest: ServerRequest, file: TapirFile, maxBytes: Option[Long]): F[Unit] = diff --git a/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/NettyIdRequestBody.scala b/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/NettyIdRequestBody.scala index 5b1aaf8980..a14dccae8c 100644 --- a/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/NettyIdRequestBody.scala +++ b/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/NettyIdRequestBody.scala @@ -16,8 +16,8 @@ private[netty] class NettyIdRequestBody(val createFile: ServerRequest => TapirFi override implicit val monad: MonadError[Id] = idMonad override val streams: capabilities.Streams[NoStreams] = NoStreams - override def publisherToBytes(publisher: Publisher[HttpContent], maxBytes: Option[Long]): Array[Byte] = - SimpleSubscriber.processAllBlocking(publisher, maxBytes) + override def publisherToBytes(publisher: Publisher[HttpContent], contentLength: Option[Int], maxBytes: Option[Long]): Array[Byte] = + SimpleSubscriber.processAllBlocking(publisher, contentLength, maxBytes) override def writeToFile(serverRequest: ServerRequest, file: TapirFile, maxBytes: Option[Long]): Unit = serverRequest.underlying match { diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyFutureRequestBody.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyFutureRequestBody.scala index c6dcbf0a9d..41b4e023ca 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyFutureRequestBody.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyFutureRequestBody.scala @@ -18,8 +18,8 @@ private[netty] class NettyFutureRequestBody(val createFile: ServerRequest => Fut override val streams: capabilities.Streams[NoStreams] = NoStreams override implicit val monad: MonadError[Future] = new FutureMonad() - override def publisherToBytes(publisher: Publisher[HttpContent], maxBytes: Option[Long]): Future[Array[Byte]] = - SimpleSubscriber.processAll(publisher, maxBytes) + override def publisherToBytes(publisher: Publisher[HttpContent], contentLength: Option[Int], maxBytes: Option[Long]): Future[Array[Byte]] = + SimpleSubscriber.processAll(publisher, contentLength, maxBytes) override def writeToFile(serverRequest: ServerRequest, file: TapirFile, maxBytes: Option[Long]): Future[Unit] = serverRequest.underlying match { diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyRequestBody.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyRequestBody.scala index 9d1375e7a5..fc11b33db4 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyRequestBody.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyRequestBody.scala @@ -16,6 +16,7 @@ import sttp.tapir.InputStreamRange import java.io.ByteArrayInputStream import java.nio.ByteBuffer import sttp.capabilities.Streams +import sttp.model.HeaderNames /** Common logic for processing request body in all Netty backends. It requires particular backends to implement a few operations. */ private[netty] trait NettyRequestBody[F[_], S <: Streams[S]] extends RequestBody[F, S] { @@ -29,12 +30,14 @@ private[netty] trait NettyRequestBody[F[_], S <: Streams[S]] extends RequestBody * * @param publisher * reactive publisher emitting byte chunks. + * @param contentLength + * Total content length, if known * @param maxBytes * optional request length limit. If exceeded, The effect `F` is failed with a [[sttp.capabilities.StreamMaxLengthExceededException]] * @return * An effect which finishes with a single array of all collected bytes. */ - def publisherToBytes(publisher: Publisher[HttpContent], maxBytes: Option[Long]): F[Array[Byte]] + def publisherToBytes(publisher: Publisher[HttpContent], contentLength: Option[Int], maxBytes: Option[Long]): F[Array[Byte]] /** Backend-specific way to process all elements emitted by a Publisher[HttpContent] and write their bytes into a file. * @@ -76,7 +79,8 @@ private[netty] trait NettyRequestBody[F[_], S <: Streams[S]] extends RequestBody case r: FullHttpRequest if r.content() == Unpooled.EMPTY_BUFFER => // Empty request monad.unit(Array.empty[Byte]) case req: StreamedHttpRequest => - publisherToBytes(req, maxBytes) + val contentLength = Option(req.headers().get(HeaderNames.ContentLength)).map(_.toInt) + publisherToBytes(req, contentLength, maxBytes) case other => monad.error(new UnsupportedOperationException(s"Unexpected Netty request of type ${other.getClass().getName()}")) } } diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/SimpleSubscriber.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/SimpleSubscriber.scala index 5a2721c8e7..5408274822 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/SimpleSubscriber.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/reactivestreams/SimpleSubscriber.scala @@ -1,20 +1,19 @@ package sttp.tapir.server.netty.internal.reactivestreams -import io.netty.buffer.ByteBufUtil +import io.netty.buffer.{ByteBuf, ByteBufUtil} import io.netty.handler.codec.http.HttpContent import org.reactivestreams.{Publisher, Subscription} +import sttp.capabilities.StreamMaxLengthExceededException -import java.util.concurrent.ConcurrentLinkedQueue -import scala.collection.JavaConverters._ -import scala.concurrent.{Future, Promise} import java.util.concurrent.LinkedBlockingQueue +import scala.concurrent.{Future, Promise} -private[netty] class SimpleSubscriber() extends PromisingSubscriber[Array[Byte], HttpContent] { +private[netty] class SimpleSubscriber(contentLength: Option[Int]) extends PromisingSubscriber[Array[Byte], HttpContent] { private var subscription: Subscription = _ - private val chunks = new ConcurrentLinkedQueue[Array[Byte]]() - private var size = 0 private val resultPromise = Promise[Array[Byte]]() - private val resultBlockingQueue = new LinkedBlockingQueue[Either[Throwable, Array[Byte]]]() + @volatile private var totalLength = 0 + private val resultBlockingQueue = new LinkedBlockingQueue[Either[Throwable, Array[Byte]]](1) + @volatile private var buffers = Vector[ByteBuf]() override def future: Future[Array[Byte]] = resultPromise.future def resultBlocking(): Either[Throwable, Array[Byte]] = resultBlockingQueue.take() @@ -25,44 +24,93 @@ private[netty] class SimpleSubscriber() extends PromisingSubscriber[Array[Byte], } override def onNext(content: HttpContent): Unit = { - val array = ByteBufUtil.getBytes(content.content()) - content.release() - size += array.length - chunks.add(array) - subscription.request(1) + val byteBuf = content.content() + // If expected content length is known, and we receive exactly this amount of bytes, we assume there's only one chunk and + // we can immediately return it without going through the buffer list. + if (contentLength.contains(byteBuf.readableBytes())) { + val finalArray = ByteBufUtil.getBytes(byteBuf) + byteBuf.release() + if (!resultBlockingQueue.offer(Right(finalArray))) { + // Queue full, which is unexpected. The previous chunk was supposed the be the only one. A malformed request perhaps? + subscription.cancel() + } else { + resultPromise.success(finalArray) + subscription.request(1) + } + } else { + buffers = buffers :+ byteBuf + totalLength += byteBuf.readableBytes() + subscription.request(1) + } } override def onError(t: Throwable): Unit = { - chunks.clear() - resultBlockingQueue.add(Left(t)) + buffers.foreach { buf => + val _ = buf.release() + } + buffers = Vector.empty + resultBlockingQueue.offer(Left(t)) resultPromise.failure(t) } override def onComplete(): Unit = { - val result = new Array[Byte](size) - val _ = chunks.asScala.foldLeft(0)((currentPosition, array) => { - System.arraycopy(array, 0, result, currentPosition, array.length) - currentPosition + array.length - }) - chunks.clear() - resultBlockingQueue.add(Right(result)) - resultPromise.success(result) + if (!buffers.isEmpty) { + val mergedArray = new Array[Byte](totalLength) + var currentIndex = 0 + buffers.foreach { buf => + val length = buf.readableBytes() + buf.getBytes(buf.readerIndex(), mergedArray, currentIndex, length) + currentIndex += length + val _ = buf.release() + } + buffers = Vector.empty + if (!resultBlockingQueue.offer(Right(mergedArray))) { + // Result queue full, which is unexpected. + resultPromise.failure(new IllegalStateException("Calling onComplete after result was already returned")) + } else + resultPromise.success(mergedArray) + } else { + () // result already sent in onNext + } } + } object SimpleSubscriber { - def processAll(publisher: Publisher[HttpContent], maxBytes: Option[Long]): Future[Array[Byte]] = { - val subscriber = new SimpleSubscriber() - publisher.subscribe(maxBytes.map(max => new LimitedLengthSubscriber(max, subscriber)).getOrElse(subscriber)) - subscriber.future - } - def processAllBlocking(publisher: Publisher[HttpContent], maxBytes: Option[Long]): Array[Byte] = { - val subscriber = new SimpleSubscriber() - publisher.subscribe(maxBytes.map(max => new LimitedLengthSubscriber(max, subscriber)).getOrElse(subscriber)) - subscriber.resultBlocking() match { - case Right(result) => result - case Left(e) => throw e + def processAll(publisher: Publisher[HttpContent], contentLength: Option[Int], maxBytes: Option[Long]): Future[Array[Byte]] = + maxBytes match { + case Some(max) if (contentLength.exists(_ > max)) => Future.failed(StreamMaxLengthExceededException(max)) + case Some(max) => { + val subscriber = new SimpleSubscriber(contentLength) + publisher.subscribe(new LimitedLengthSubscriber(max, subscriber)) + subscriber.future + } + case None => { + val subscriber = new SimpleSubscriber(contentLength) + publisher.subscribe(subscriber) + subscriber.future + } + } + + def processAllBlocking(publisher: Publisher[HttpContent], contentLength: Option[Int], maxBytes: Option[Long]): Array[Byte] = + maxBytes match { + case Some(max) if (contentLength.exists(_ > max)) => throw new StreamMaxLengthExceededException(max) + case Some(max) => { + val subscriber = new SimpleSubscriber(contentLength) + publisher.subscribe(new LimitedLengthSubscriber(max, subscriber)) + subscriber.resultBlocking() match { + case Right(result) => result + case Left(e) => throw e + } + } + case None => { + val subscriber = new SimpleSubscriber(contentLength) + publisher.subscribe(subscriber) + subscriber.resultBlocking() match { + case Right(result) => result + case Left(e) => throw e + } + } } - } } diff --git a/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/internal/NettyZioRequestBody.scala b/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/internal/NettyZioRequestBody.scala index 2e551cad81..d7fcedb1d1 100644 --- a/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/internal/NettyZioRequestBody.scala +++ b/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/internal/NettyZioRequestBody.scala @@ -19,7 +19,11 @@ private[zio] class NettyZioRequestBody[Env]( override val streams: ZioStreams = ZioStreams override implicit val monad: MonadError[RIO[Env, *]] = new RIOMonadError[Env] - override def publisherToBytes(publisher: Publisher[HttpContent], maxBytes: Option[Long]): RIO[Env, Array[Byte]] = + override def publisherToBytes( + publisher: Publisher[HttpContent], + contentLength: Option[Int], + maxBytes: Option[Long] + ): RIO[Env, Array[Byte]] = streamCompatible.fromPublisher(publisher, maxBytes).run(ZSink.collectAll[Byte]).map(_.toArray) override def writeToFile(serverRequest: ServerRequest, file: TapirFile, maxBytes: Option[Long]): RIO[Env, Unit] =