Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Lazy InputStream over Netty HttpContent Publisher #3637

Merged
merged 18 commits into from
Mar 27, 2024
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ trait MonadErrorSyntax {
})

override def ensure[T](f: G[T], e: => G[Unit]): G[T] = fk(mef.ensure(gK(f), gK(e)))

override def blocking[T](t: => T): G[T] = fk(mef.blocking(t))
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,5 @@ class RIOMonadError[R] extends MonadError[RIO[R, *]] {
override def suspend[T](t: => RIO[R, T]): RIO[R, T] = ZIO.suspend(t)
override def flatten[T](ffa: RIO[R, RIO[R, T]]): RIO[R, T] = ffa.flatten
override def ensure[T](f: RIO[R, T], e: => RIO[R, Unit]): RIO[R, T] = f.ensuring(e.ignore)
override def blocking[T](t: => T): RIO[R, T] = ZIO.attemptBlocking(t)
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,7 @@ object TestMonadError {
rt.catchSome(h)

override def ensure[T](f: TestEffect[T], e: => TestEffect[Unit]): TestEffect[T] = f.ensuring(e.ignore)

override def blocking[T](t: => T): TestEffect[T] = ZIO.attemptBlocking(t)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,17 @@ import io.netty.buffer.Unpooled
import io.netty.handler.codec.http.{FullHttpRequest, HttpContent}
import org.playframework.netty.http.StreamedHttpRequest
import org.reactivestreams.Publisher
import sttp.capabilities.Streams
import sttp.model.HeaderNames
import sttp.monad.MonadError
import sttp.monad.syntax._
import sttp.tapir.{FileRange, InputStreamRange, RawBodyType, TapirFile}
import sttp.tapir.model.ServerRequest
import sttp.tapir.server.interpreter.RequestBody
import sttp.tapir.RawBodyType
import sttp.tapir.TapirFile
import sttp.tapir.server.interpreter.RawValue
import sttp.tapir.FileRange
import sttp.tapir.InputStreamRange
import java.io.ByteArrayInputStream
import sttp.tapir.server.interpreter.{RawValue, RequestBody}
import sttp.tapir.server.netty.internal.reactivestreams.SubscriberInputStream

import java.io.{ByteArrayInputStream, InputStream}
import java.nio.ByteBuffer
import sttp.capabilities.Streams
import sttp.model.HeaderNames

/** Common logic for processing request body in all Netty backends. It requires particular backends to implement a few operations. */
private[netty] trait NettyRequestBody[F[_], S <: Streams[S]] extends RequestBody[F, S] {
Expand Down Expand Up @@ -64,17 +62,16 @@ private[netty] trait NettyRequestBody[F[_], S <: Streams[S]] extends RequestBody
case RawBodyType.ByteBufferBody =>
readAllBytes(serverRequest, maxBytes).map(bs => RawValue(ByteBuffer.wrap(bs)))
case RawBodyType.InputStreamBody =>
// Possibly can be optimized to avoid loading all data eagerly into memory
readAllBytes(serverRequest, maxBytes).map(bs => RawValue(new ByteArrayInputStream(bs)))
monad.eval(RawValue(readAsStream(serverRequest, maxBytes)))
case RawBodyType.InputStreamRangeBody =>
// Possibly can be optimized to avoid loading all data eagerly into memory
readAllBytes(serverRequest, maxBytes).map(bs => RawValue(InputStreamRange(() => new ByteArrayInputStream(bs))))
monad.unit(RawValue(InputStreamRange(() => readAsStream(serverRequest, maxBytes))))
case RawBodyType.FileBody =>
for {
file <- createFile(serverRequest)
_ <- writeToFile(serverRequest, file, maxBytes)
} yield RawValue(FileRange(file), Seq(FileRange(file)))
case _: RawBodyType.MultipartBody => monad.error(new UnsupportedOperationException)
case _: RawBodyType.MultipartBody =>
monad.error(new UnsupportedOperationException)
}

private def readAllBytes(serverRequest: ServerRequest, maxBytes: Option[Long]): F[Array[Byte]] =
Expand All @@ -84,7 +81,19 @@ private[netty] trait NettyRequestBody[F[_], S <: Streams[S]] extends RequestBody
case req: StreamedHttpRequest =>
val contentLength = Option(req.headers().get(HeaderNames.ContentLength)).map(_.toLong)
publisherToBytes(req, contentLength, maxBytes)
case other =>
case other =>
monad.error(new UnsupportedOperationException(s"Unexpected Netty request of type ${other.getClass.getName}"))
}

/** Exposes the request body as an [[InputStream]] instead of collecting it into a byte array first.
  *
  * An empty `FullHttpRequest` yields an empty stream, a `StreamedHttpRequest` is handed to
  * `SubscriberInputStream.processAsStream` together with its `Content-Length` header value (if present) and `maxBytes`.
  * Any other underlying request type is rejected with an `UnsupportedOperationException`.
  */
private def readAsStream(serverRequest: ServerRequest, maxBytes: Option[Long]): InputStream =
  serverRequest.underlying match {
    // Empty request
    case full: FullHttpRequest if full.content() == Unpooled.EMPTY_BUFFER =>
      InputStream.nullInputStream()
    case streamed: StreamedHttpRequest =>
      val declaredLength = Option(streamed.headers().get(HeaderNames.ContentLength)).map(_.toLong)
      SubscriberInputStream.processAsStream(streamed, declaredLength, maxBytes)
    case unexpected =>
      throw new UnsupportedOperationException(s"Unexpected Netty request of type ${unexpected.getClass.getName}")
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
package sttp.tapir.server.netty.internal.reactivestreams

import io.netty.buffer.ByteBuf
import io.netty.handler.codec.http.HttpContent
import org.reactivestreams.{Publisher, Subscriber, Subscription}
import sttp.capabilities.StreamMaxLengthExceededException

import java.io.{IOException, InputStream}
import java.util.concurrent.LinkedBlockingQueue
import scala.annotation.tailrec
import scala.concurrent.Promise

/** A blocking input stream that reads from a reactive streams publisher of [[HttpContent]].
* @param maxBufferedChunks
* maximum number of unread chunks that can be buffered before blocking the publisher
*/
private[netty] class SubscriberInputStream(maxBufferedChunks: Int = 1)
extends InputStream with Subscriber[HttpContent] {

require(maxBufferedChunks > 0)

import SubscriberInputStream._

// volatile because used in both InputStream & Subscriber methods
@volatile private[this] var closed = false
// Calls on the subscription must be synchronized in order to satisfy the Reactive Streams spec
// (https://github.com/reactive-streams/reactive-streams-jvm?tab=readme-ov-file#2-subscriber-code - rule 7)
// because they are called both from InputStream & Subscriber methods.
private[this] var subscription: Subscription = _
private[this] var currentItem: Item = _
// the queue serves as a buffer to allow for possible parallelism between the subscriber and the publisher
private val queue = new LinkedBlockingQueue[Item](maxBufferedChunks + 1) // +1 to have a spot for End/Error

/** Returns the item currently being consumed, fetching the next one from the queue when there is none.
  *
  * @param blocking
  *   when `true`, waits for an item with `queue.take()`; when `false`, uses `queue.poll()`, which yields `null` if nothing is buffered
  * @return
  *   the current item, or `null` when `blocking` is `false` and the queue is empty
  */
private def readItem(blocking: Boolean): Item = {
  if (currentItem eq null) {
    currentItem = if (blocking) queue.take() else queue.poll()
    currentItem match {
      // a chunk has just been taken out of the queue, so one more can be requested from the publisher
      case _: Chunk => synchronized(subscription.request(1))
      case _        =>
    }
  }
  currentItem
}

/** Number of bytes that can be read without blocking: the readable bytes of the current chunk (polling the queue
  * without waiting if there is no current item), or 0 when no chunk is at hand.
  */
override def available(): Int = {
  val buffered = readItem(blocking = false)
  buffered match {
    case Chunk(remaining) => remaining.readableBytes()
    case _                => 0
  }
}

/** Reads a single byte by delegating to the array-based `read`.
  *
  * @return
  *   the byte as an unsigned value in the range 0..255, or -1 at the end of the stream
  */
override def read(): Int = {
  val buffer = new Array[Byte](1)
  // Mask to 0..255: a raw Byte is sign-extended when widened to Int, so 0xFF would be returned as -1 (EOF)
  if (read(buffer) == -1) -1 else buffer(0) & 0xff
}

/** Reads up to `len` bytes into `b` starting at offset `off`, waiting until at least one byte, the end of the stream
  * or an error is available.
  *
  * @return
  *   the number of bytes read (0 only when `len` is 0), or -1 once the publisher has completed
  * @throws IOException
  *   if the stream has been closed; an error signalled by the publisher is rethrown unchanged
  */
override def read(b: Array[Byte], off: Int, len: Int): Int =
  if (closed) throw new IOException("Stream closed")
  else if (len == 0) 0
  else readFromChunks(b, off, len)

/** Copies bytes from the current chunk, moving on to the next item when a chunk turns out to be empty. */
@tailrec private def readFromChunks(b: Array[Byte], off: Int, len: Int): Int =
  readItem(blocking = true) match {
    case Chunk(data) =>
      val toRead = Math.min(len, data.readableBytes())
      data.readBytes(b, off, toRead)
      if (data.readableBytes() == 0) {
        // chunk fully consumed: release its buffer and make room for the next item
        data.release()
        currentItem = null
      }
      // An empty chunk must not surface as a 0-byte read: for len > 0 the InputStream contract
      // requires at least one byte, -1 or an exception.
      if (toRead > 0) toRead else readFromChunks(b, off, len)
    case Error(cause) => throw cause
    case End          => -1
  }

/** Closes the stream: marks it closed, cancels the subscription if one has already been received, and releases
  * the partially read chunk as well as all chunks still waiting in the queue. Calls on a closed stream are no-ops.
  */
override def close(): Unit = if (!closed) {
  // Same lock as onSubscribe: a subscription arriving concurrently is either cancelled here
  // or, because it then observes `closed`, cancelled in onSubscribe.
  synchronized {
    closed = true
    // close() may be called before onSubscribe, in which case there is nothing to cancel yet
    if (subscription ne null) subscription.cancel()
  }
  // a chunk that was only partially read still holds its buffer; release it so it does not leak
  currentItem match {
    case Chunk(data) =>
      data.release()
      currentItem = null
    case _ =>
  }
  clearQueue()
}

/** Polls the queue and releases the buffer of every chunk found, stopping when the queue is empty or a non-chunk
  * item is polled.
  */
@tailrec private def clearQueue(): Unit =
  queue.poll() match {
    case Chunk(data) =>
      data.release()
      clearQueue()
    case _ =>
  }

/** Stores the subscription and requests the initial `maxBufferedChunks` chunks.
  *
  * A subscription received after the stream has been closed, or while another subscription is already held
  * (Reactive Streams rule 2.5), is cancelled immediately instead.
  *
  * @throws NullPointerException
  *   if `s` is `null`
  */
override def onSubscribe(s: Subscription): Unit = synchronized {
  if (s eq null) {
    throw new NullPointerException("Subscription must not be null")
  }
  if (closed || (subscription ne null)) {
    s.cancel()
  } else {
    subscription = s
    subscription.request(maxBufferedChunks)
  }
}

/** Buffers the content of a received chunk in the queue.
  *
  * If the queue does not accept the chunk, the chunk is released and the subscription cancelled. If the stream is
  * found closed after the chunk has been enqueued, the queue is drained so that the chunk's buffer gets released.
  */
override def onNext(chunk: HttpContent): Unit = {
if (!queue.offer(Chunk(chunk.content()))) {
// This should be impossible according to the Reactive Streams spec,
// if it happens then it's a bug in the implementation of the subscriber or publisher
chunk.release()
synchronized(subscription.cancel())
} else if (closed) {
// `closed` is checked after the offer, so a chunk enqueued around the time of close() is still released
clearQueue()
}
}

/** Enqueues the failure so that it is rethrown to the reader; ignored when the stream is already closed. */
override def onError(t: Throwable): Unit = {
  val stillOpen = !closed
  if (stillOpen) queue.offer(Error(t))
}

/** Enqueues the end-of-stream marker for the reader; ignored when the stream is already closed. */
override def onComplete(): Unit = {
  val stillOpen = !closed
  if (stillOpen) queue.offer(End)
}
}
private[netty] object SubscriberInputStream {

  /** Items passed from the subscriber side to the reading side through the queue. */
  private sealed abstract class Item

  /** A chunk of body data. */
  private case class Chunk(data: ByteBuf) extends Item

  /** A failure signalled by the publisher. */
  private case class Error(cause: Throwable) extends Item

  /** Marker for the completion of the publisher. */
  private object End extends Item

  /** Subscribes a new [[SubscriberInputStream]] to `publisher` and returns it.
    *
    * @param contentLength
    *   declared length of the content, if known; when it exceeds `maxBytes`, the call fails before subscribing
    * @param maxBytes
    *   optional limit; when defined, the stream is subscribed through a `LimitedLengthSubscriber`
    * @param maxBufferedChunks
    *   passed to the [[SubscriberInputStream]] constructor
    * @throws StreamMaxLengthExceededException
    *   if `contentLength` is greater than `maxBytes`
    */
  def processAsStream(
      publisher: Publisher[HttpContent],
      contentLength: Option[Long],
      maxBytes: Option[Long],
      maxBufferedChunks: Int = 1
  ): InputStream = {
    // reject up front when the declared length already exceeds the limit
    maxBytes
      .filter(limit => contentLength.exists(_ > limit))
      .foreach(limit => throw StreamMaxLengthExceededException(limit))

    val inputStream = new SubscriberInputStream(maxBufferedChunks)
    val downstream = maxBytes match {
      case Some(limit) => new LimitedLengthSubscriber(limit, inputStream)
      case None        => inputStream
    }
    publisher.subscribe(downstream)
    inputStream
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
package sttp.tapir.server.netty.internal.reactivestreams

import cats.effect.IO
import cats.effect.kernel.Resource
import cats.effect.unsafe.IORuntime
import fs2.Stream
import fs2.interop.reactivestreams._
import io.netty.buffer.Unpooled
import io.netty.handler.codec.http.DefaultHttpContent
import org.scalatest.freespec.AnyFreeSpec
import org.scalatest.matchers.should.Matchers
import org.scalactic.source.Position

import scala.util.Random

/** Tests reading from [[SubscriberInputStream]] when it is subscribed to a publisher created from an fs2 stream. */
class SubscriberInputStreamTest extends AnyFreeSpec with Matchers {
private implicit def runtime: IORuntime = IORuntime.global

/** Publishes `size` random bytes as `DefaultHttpContent` chunks of at most `chunkLimit` bytes and asserts that
  * reading the whole [[SubscriberInputStream]] returns exactly those bytes.
  *
  * @param size
  *   total number of bytes to publish
  * @param chunkLimit
  *   upper bound on the size of a single published chunk
  * @param maxBufferedChunks
  *   passed to the [[SubscriberInputStream]] constructor
  */
private def testReading(
size: Int,
chunkLimit: Int = 1024,
maxBufferedChunks: Int = 1
)(implicit pos: Position): Unit = {
val bytes = new Array[Byte](size)
Random.nextBytes(bytes)

val publisherResource = Stream
.emits(bytes)
.chunkLimit(chunkLimit)
.map(ch => new DefaultHttpContent(Unpooled.wrappedBuffer(ch.toByteBuffer)))
.covary[IO]
.toUnicastPublisher

val io = publisherResource.use { publisher =>
IO {
val subscriberInputStream = new SubscriberInputStream(maxBufferedChunks)
publisher.subscribe(subscriberInputStream)
subscriberInputStream.readAllBytes() shouldBe bytes
}
}

io.unsafeRunSync()
}

"empty stream" in {
testReading(0)
}

"single chunk stream" in {
testReading(10)
}

"multiple chunks" in {
testReading(100, 10)
}

"multiple chunks with larger buffer" in {
testReading(100, 10, maxBufferedChunks = 5)
}

// 105 bytes in chunks of at most 10: the last chunk carries only 5 bytes
"multiple chunks with smaller last chunk" in {
testReading(105, 10)
}

"closing the stream should cancel the subscription" in {
// set by the stream finalizer below when the stream ends with ExitCase.Canceled
var canceled = false

val publisherResource =
Stream
.emits(Array.fill(1024)(0.toByte))
.chunkLimit(100)
.map(ch => new DefaultHttpContent(Unpooled.wrappedBuffer(ch.toByteBuffer)))
.covary[IO]
.onFinalizeCase {
case Resource.ExitCase.Canceled => IO { canceled = true }
case _ => IO.unit
}
.toUnicastPublisher

publisherResource
.use(publisher =>
IO {
val stream = new SubscriberInputStream()
publisher.subscribe(stream)

// read only part of the 1024 published bytes, then close while the publisher still has data left
stream.readNBytes(120).length shouldBe 120
stream.close()
}
)
.unsafeRunSync()

canceled shouldBe true
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -735,7 +735,7 @@ class ServerBasicTests[F[_], OPTIONS, ROUTE](

def inputStreamTests(): List[Test] = List(
testServer(in_input_stream_out_input_stream)((is: InputStream) =>
pureResult((new ByteArrayInputStream(inputStreamToByteArray(is)): InputStream).asRight[Unit])
blockingResult((new ByteArrayInputStream(inputStreamToByteArray(is)): InputStream).asRight[Unit])
) { (backend, baseUri) => basicRequest.post(uri"$baseUri/api/echo").body("mango").send(backend).map(_.body shouldBe Right("mango")) },
testServer(in_string_out_stream_with_header)(_ => pureResult(Right((new ByteArrayInputStream(Array.fill[Byte](128)(0)), Some(128))))) {
(backend, baseUri) =>
Expand Down Expand Up @@ -795,7 +795,7 @@ class ServerBasicTests[F[_], OPTIONS, ROUTE](
"checks payload limit and returns OK on content length below or equal max (request)"
)(i => {
// Forcing server logic to drain the InputStream
suspendResult(i.readAllBytes()).map(_ => new ByteArrayInputStream(Array.empty[Byte]).asRight[Unit])
blockingResult(i.readAllBytes()).map(_ => new ByteArrayInputStream(Array.empty[Byte]).asRight[Unit])
}) { (backend, baseUri) =>
val tooLargeBody: String = List.fill(maxLength)('x').mkString
basicRequest.post(uri"$baseUri/api/echo").body(tooLargeBody).response(asByteArray).send(backend).map { r =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ class ServerMetricsTest[F[_], OPTIONS, ROUTE](createServerTest: CreateServerTest
testServer(
in_input_stream_out_input_stream.name("metrics"),
interceptors = (ci: CustomiseInterceptors[F, OPTIONS]) => ci.metricsInterceptor(metrics)
)(is => (new ByteArrayInputStream(inputStreamToByteArray(is)): InputStream).asRight[Unit].unit) { (backend, baseUri) =>
)(is => blockingResult((new ByteArrayInputStream(inputStreamToByteArray(is)): InputStream).asRight[Unit])) { (backend, baseUri) =>
basicRequest
.post(uri"$baseUri/api/echo")
.body("okoń")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import sttp.monad.MonadError
package object tests {
val backendResource: Resource[IO, SttpBackend[IO, Fs2Streams[IO] with WebSockets]] = HttpClientFs2Backend.resource()
val basicStringRequest: PartialRequest[String, Any] = basicRequest.response(asStringAlways)
def pureResult[F[_]: MonadError, T](t: T): F[T] = implicitly[MonadError[F]].unit(t)
def suspendResult[F[_]: MonadError, T](t: => T): F[T] = implicitly[MonadError[F]].eval(t)
def pureResult[F[_]: MonadError, T](t: T): F[T] = MonadError[F].unit(t)
def suspendResult[F[_]: MonadError, T](t: => T): F[T] = MonadError[F].eval(t)
def blockingResult[F[_]: MonadError, T](t: => T): F[T] = MonadError[F].blocking(t)
}
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ object VertxCatsServerInterpreter {
override def suspend[T](t: => F[T]): F[T] = F.defer(t)
override def flatten[T](ffa: F[F[T]]): F[T] = F.flatten(ffa)
override def ensure[T](f: F[T], e: => F[Unit]): F[T] = F.guaranteeCase(f)(_ => e)
override def blocking[T](t: => T): F[T] = F.blocking(t)
}

private[cats] class CatsFFromVFuture[F[_]: Async] extends FromVFuture[F] {
Expand Down
Loading