Skip to content

Commit

Permalink
Optimize SimpleSubscriber for Netty (#3583)
Browse files Browse the repository at this point in the history
  • Loading branch information
kciesielski authored Mar 11, 2024
1 parent a19b961 commit 7301c06
Show file tree
Hide file tree
Showing 6 changed files with 99 additions and 43 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ private[cats] class NettyCatsRequestBody[F[_]: Async](

override implicit val monad: MonadError[F] = new CatsMonadError()

override def publisherToBytes(publisher: Publisher[HttpContent], maxBytes: Option[Long]): F[Array[Byte]] =
override def publisherToBytes(publisher: Publisher[HttpContent], contentLength: Option[Int], maxBytes: Option[Long]): F[Array[Byte]] =
streamCompatible.fromPublisher(publisher, maxBytes).compile.to(Chunk).map(_.toArray[Byte])

override def writeToFile(serverRequest: ServerRequest, file: TapirFile, maxBytes: Option[Long]): F[Unit] =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ private[netty] class NettyIdRequestBody(val createFile: ServerRequest => TapirFi
override implicit val monad: MonadError[Id] = idMonad
override val streams: capabilities.Streams[NoStreams] = NoStreams

override def publisherToBytes(publisher: Publisher[HttpContent], maxBytes: Option[Long]): Array[Byte] =
SimpleSubscriber.processAllBlocking(publisher, maxBytes)
override def publisherToBytes(publisher: Publisher[HttpContent], contentLength: Option[Int], maxBytes: Option[Long]): Array[Byte] =
SimpleSubscriber.processAllBlocking(publisher, contentLength, maxBytes)

override def writeToFile(serverRequest: ServerRequest, file: TapirFile, maxBytes: Option[Long]): Unit =
serverRequest.underlying match {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ private[netty] class NettyFutureRequestBody(val createFile: ServerRequest => Fut
override val streams: capabilities.Streams[NoStreams] = NoStreams
override implicit val monad: MonadError[Future] = new FutureMonad()

override def publisherToBytes(publisher: Publisher[HttpContent], maxBytes: Option[Long]): Future[Array[Byte]] =
SimpleSubscriber.processAll(publisher, maxBytes)
override def publisherToBytes(publisher: Publisher[HttpContent], contentLength: Option[Int], maxBytes: Option[Long]): Future[Array[Byte]] =
SimpleSubscriber.processAll(publisher, contentLength, maxBytes)

override def writeToFile(serverRequest: ServerRequest, file: TapirFile, maxBytes: Option[Long]): Future[Unit] =
serverRequest.underlying match {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import sttp.tapir.InputStreamRange
import java.io.ByteArrayInputStream
import java.nio.ByteBuffer
import sttp.capabilities.Streams
import sttp.model.HeaderNames

/** Common logic for processing request body in all Netty backends. It requires particular backends to implement a few operations. */
private[netty] trait NettyRequestBody[F[_], S <: Streams[S]] extends RequestBody[F, S] {
Expand All @@ -29,12 +30,14 @@ private[netty] trait NettyRequestBody[F[_], S <: Streams[S]] extends RequestBody
*
* @param publisher
* reactive publisher emitting byte chunks.
* @param contentLength
total content length of the request body, if known
* @param maxBytes
* optional request length limit. If exceeded, The effect `F` is failed with a [[sttp.capabilities.StreamMaxLengthExceededException]]
* @return
* An effect which finishes with a single array of all collected bytes.
*/
def publisherToBytes(publisher: Publisher[HttpContent], maxBytes: Option[Long]): F[Array[Byte]]
def publisherToBytes(publisher: Publisher[HttpContent], contentLength: Option[Int], maxBytes: Option[Long]): F[Array[Byte]]

/** Backend-specific way to process all elements emitted by a Publisher[HttpContent] and write their bytes into a file.
*
Expand Down Expand Up @@ -76,7 +79,8 @@ private[netty] trait NettyRequestBody[F[_], S <: Streams[S]] extends RequestBody
case r: FullHttpRequest if r.content() == Unpooled.EMPTY_BUFFER => // Empty request
monad.unit(Array.empty[Byte])
case req: StreamedHttpRequest =>
publisherToBytes(req, maxBytes)
val contentLength = Option(req.headers().get(HeaderNames.ContentLength)).map(_.toInt)
publisherToBytes(req, contentLength, maxBytes)
case other => monad.error(new UnsupportedOperationException(s"Unexpected Netty request of type ${other.getClass().getName()}"))
}
}
Original file line number Diff line number Diff line change
@@ -1,20 +1,19 @@
package sttp.tapir.server.netty.internal.reactivestreams

import io.netty.buffer.ByteBufUtil
import io.netty.buffer.{ByteBuf, ByteBufUtil}
import io.netty.handler.codec.http.HttpContent
import org.reactivestreams.{Publisher, Subscription}
import sttp.capabilities.StreamMaxLengthExceededException

import java.util.concurrent.ConcurrentLinkedQueue
import scala.collection.JavaConverters._
import scala.concurrent.{Future, Promise}
import java.util.concurrent.LinkedBlockingQueue
import scala.concurrent.{Future, Promise}

private[netty] class SimpleSubscriber() extends PromisingSubscriber[Array[Byte], HttpContent] {
private[netty] class SimpleSubscriber(contentLength: Option[Int]) extends PromisingSubscriber[Array[Byte], HttpContent] {
private var subscription: Subscription = _
private val chunks = new ConcurrentLinkedQueue[Array[Byte]]()
private var size = 0
private val resultPromise = Promise[Array[Byte]]()
private val resultBlockingQueue = new LinkedBlockingQueue[Either[Throwable, Array[Byte]]]()
@volatile private var totalLength = 0
private val resultBlockingQueue = new LinkedBlockingQueue[Either[Throwable, Array[Byte]]](1)
@volatile private var buffers = Vector[ByteBuf]()

override def future: Future[Array[Byte]] = resultPromise.future
def resultBlocking(): Either[Throwable, Array[Byte]] = resultBlockingQueue.take()
Expand All @@ -25,44 +24,93 @@ private[netty] class SimpleSubscriber() extends PromisingSubscriber[Array[Byte],
}

/** Handles one emitted chunk. Two paths:
  *   - fast path: if the expected `contentLength` is known and this chunk carries exactly that many readable bytes, it is assumed to
  *     be the only chunk — its bytes are copied out and delivered immediately, skipping the accumulation buffer;
  *   - slow path: the chunk's `ByteBuf` is retained in `buffers` (released later in `onComplete`/`onError`) and more data is requested.
  */
override def onNext(content: HttpContent): Unit = {
  val byteBuf = content.content()
  // If expected content length is known, and we receive exactly this amount of bytes, we assume there's only one chunk and
  // we can immediately return it without going through the buffer list.
  if (contentLength.contains(byteBuf.readableBytes())) {
    val finalArray = ByteBufUtil.getBytes(byteBuf)
    byteBuf.release()
    if (!resultBlockingQueue.offer(Right(finalArray))) {
      // Queue full, which is unexpected. The previous chunk was supposed to be the only one. A malformed request perhaps?
      subscription.cancel()
    } else {
      resultPromise.success(finalArray)
      // Keep requesting so the publisher can still signal onComplete.
      subscription.request(1)
    }
  } else {
    // Retain the buffer without copying; its bytes are read and it is released in onComplete (or onError).
    buffers = buffers :+ byteBuf
    totalLength += byteBuf.readableBytes()
    subscription.request(1)
  }
}

/** Fails both result channels and releases all retained chunks so no Netty `ByteBuf` leaks.
  *
  * Uses `offer` (the queue has capacity 1 and may already hold a fast-path result) and `tryFailure` (the promise may already be
  * completed by the fast path in `onNext`; `failure` would then throw from within a Reactive Streams signal).
  */
override def onError(t: Throwable): Unit = {
  buffers.foreach { buf =>
    val _ = buf.release()
  }
  buffers = Vector.empty
  resultBlockingQueue.offer(Left(t))
  val _ = resultPromise.tryFailure(t)
}

/** Merges all accumulated chunks into a single array and delivers it to both the blocking queue and the promise.
  *
  * If `buffers` is empty, the single-chunk fast path in `onNext` has already delivered the result, so nothing is done here.
  * Each `ByteBuf` is released as soon as its bytes are copied out.
  */
override def onComplete(): Unit = {
  if (buffers.nonEmpty) {
    val mergedArray = new Array[Byte](totalLength)
    var currentIndex = 0
    buffers.foreach { buf =>
      val length = buf.readableBytes()
      buf.getBytes(buf.readerIndex(), mergedArray, currentIndex, length)
      currentIndex += length
      val _ = buf.release()
    }
    buffers = Vector.empty
    if (!resultBlockingQueue.offer(Right(mergedArray))) {
      // Result queue full, which is unexpected: the result must have been delivered already.
      resultPromise.failure(new IllegalStateException("Calling onComplete after result was already returned"))
    } else
      resultPromise.success(mergedArray)
  } else {
    () // result already sent in onNext (single-chunk fast path)
  }
}

}

object SimpleSubscriber {

  /** Collects all chunks emitted by `publisher` into a single byte array, asynchronously.
    *
    * @param publisher
    *   reactive publisher emitting the request body chunks.
    * @param contentLength
    *   total content length, if known; enables the single-chunk fast path in [[SimpleSubscriber.onNext]].
    * @param maxBytes
    *   optional length limit; exceeding it fails the future with [[StreamMaxLengthExceededException]].
    */
  def processAll(publisher: Publisher[HttpContent], contentLength: Option[Int], maxBytes: Option[Long]): Future[Array[Byte]] =
    maxBytes match {
      // Fail fast: the declared Content-Length already exceeds the limit, no need to subscribe at all.
      case Some(max) if contentLength.exists(_ > max) => Future.failed(StreamMaxLengthExceededException(max))
      case _                                          => subscribe(publisher, contentLength, maxBytes).future
    }

  /** Blocking variant of [[processAll]]: waits for the full body and returns it, or throws the failure cause. */
  def processAllBlocking(publisher: Publisher[HttpContent], contentLength: Option[Int], maxBytes: Option[Long]): Array[Byte] =
    maxBytes match {
      // Case-class apply, consistent with processAll (no `new`).
      case Some(max) if contentLength.exists(_ > max) => throw StreamMaxLengthExceededException(max)
      case _ =>
        subscribe(publisher, contentLength, maxBytes).resultBlocking() match {
          case Right(result) => result
          case Left(e)       => throw e
        }
    }

  /** Creates a subscriber (wrapped in a [[LimitedLengthSubscriber]] when `maxBytes` is set) and subscribes it to `publisher`. */
  private def subscribe(publisher: Publisher[HttpContent], contentLength: Option[Int], maxBytes: Option[Long]): SimpleSubscriber = {
    val subscriber = new SimpleSubscriber(contentLength)
    publisher.subscribe(maxBytes.map(max => new LimitedLengthSubscriber(max, subscriber)).getOrElse(subscriber))
    subscriber
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,11 @@ private[zio] class NettyZioRequestBody[Env](
override val streams: ZioStreams = ZioStreams
override implicit val monad: MonadError[RIO[Env, *]] = new RIOMonadError[Env]

override def publisherToBytes(publisher: Publisher[HttpContent], maxBytes: Option[Long]): RIO[Env, Array[Byte]] =
override def publisherToBytes(
publisher: Publisher[HttpContent],
contentLength: Option[Int],
maxBytes: Option[Long]
): RIO[Env, Array[Byte]] =
streamCompatible.fromPublisher(publisher, maxBytes).run(ZSink.collectAll[Byte]).map(_.toArray)

override def writeToFile(serverRequest: ServerRequest, file: TapirFile, maxBytes: Option[Long]): RIO[Env, Unit] =
Expand Down

0 comments on commit 7301c06

Please sign in to comment.