Skip to content

Commit

Permalink
MaxContentLength for Pekko and Akka, Play, Vertx (#3375)
Browse files Browse the repository at this point in the history
  • Loading branch information
kciesielski authored Dec 12, 2023
1 parent 15fb9fd commit 716ce3a
Show file tree
Hide file tree
Showing 50 changed files with 369 additions and 898 deletions.
4 changes: 2 additions & 2 deletions doc/endpoint/security.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ Optional and multiple authentication inputs have some additional rules as to how

## Limiting request body length

*Supported backends*:
This feature is available for backends based on http4s, jdkhttp, Netty, and Play. More backends will be added in the near future.
*Unsupported backends*:
This feature is available for all server backends *except*: `akka-grpc`, `Armeria`, `Finatra`, `Helidon Nima`, `pekko-grpc`, `zio-http`.

Individual endpoints can be annotated with content length limit:

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,10 @@ trait AkkaHttpServerInterpreter {
toResponseBody: (Materializer, ExecutionContext) => ToResponseBody[AkkaResponseBody, AkkaStreams]
)(ses: List[ServerEndpoint[AkkaStreams with WebSockets, Future]]): Route = {
val filterServerEndpoints = FilterServerEndpoints(ses)
val interceptors = RejectInterceptor.disableWhenSingleEndpoint(akkaHttpServerOptions.interceptors, ses)
val interceptors = RejectInterceptor.disableWhenSingleEndpoint(
akkaHttpServerOptions.appendInterceptor(AkkaStreamSizeExceptionInterceptor).interceptors,
ses
)

extractExecutionContext { implicit ec =>
extractMaterializer { implicit mat =>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,19 +1,17 @@
package sttp.tapir.server.akkahttp

import akka.http.scaladsl.model.{HttpEntity, Multipart}
import akka.http.scaladsl.model._
import akka.http.scaladsl.server.RequestContext
import akka.http.scaladsl.unmarshalling.FromEntityUnmarshaller
import akka.stream.Materializer
import akka.stream.scaladsl.{FileIO, Sink}
import akka.stream.scaladsl._
import akka.util.ByteString
import sttp.capabilities.akka.AkkaStreams
import sttp.model.{Header, Part}
import sttp.tapir.model.ServerRequest
import sttp.tapir.server.interpreter.{RawValue, RequestBody}
import sttp.tapir.{FileRange, RawBodyType, RawPart, InputStreamRange}

import java.io.{ByteArrayInputStream, InputStream}

import scala.concurrent.{ExecutionContext, Future}

private[akkahttp] class AkkaRequestBody(serverOptions: AkkaHttpServerOptions)(implicit
Expand All @@ -22,29 +20,44 @@ private[akkahttp] class AkkaRequestBody(serverOptions: AkkaHttpServerOptions)(im
) extends RequestBody[Future, AkkaStreams] {
override val streams: AkkaStreams = AkkaStreams
override def toRaw[R](request: ServerRequest, bodyType: RawBodyType[R], maxBytes: Option[Long]): Future[RawValue[R]] =
toRawFromEntity(request, akkeRequestEntity(request), bodyType)
toRawFromEntity(request, requestEntity(request, maxBytes), bodyType)

override def toStream(request: ServerRequest, maxBytes: Option[Long]): streams.BinaryStream = {
val stream = akkeRequestEntity(request).dataBytes
maxBytes.map(AkkaStreams.limitBytes(stream, _)).getOrElse(stream)
requestEntity(request, maxBytes).dataBytes
}

private def akkeRequestEntity(request: ServerRequest) = request.underlying.asInstanceOf[RequestContext].request.entity
private def requestEntity(request: ServerRequest, maxBytes: Option[Long]): RequestEntity = {
val entity = request.underlying.asInstanceOf[RequestContext].request.entity
maxBytes.map(entity.withSizeLimit).getOrElse(entity)
}

private def toRawFromEntity[R](request: ServerRequest, body: HttpEntity, bodyType: RawBodyType[R]): Future[RawValue[R]] = {
private def toRawFromEntity[R](
request: ServerRequest,
body: HttpEntity,
bodyType: RawBodyType[R]
): Future[RawValue[R]] = {
bodyType match {
case RawBodyType.StringBody(_) => implicitly[FromEntityUnmarshaller[String]].apply(body).map(RawValue(_))
case RawBodyType.ByteArrayBody => implicitly[FromEntityUnmarshaller[Array[Byte]]].apply(body).map(RawValue(_))
case RawBodyType.ByteBufferBody => implicitly[FromEntityUnmarshaller[ByteString]].apply(body).map(b => RawValue(b.asByteBuffer))
case RawBodyType.InputStreamBody =>
implicitly[FromEntityUnmarshaller[Array[Byte]]].apply(body).map(b => RawValue(new ByteArrayInputStream(b)))
Future.successful(RawValue(body.dataBytes.runWith(StreamConverters.asInputStream())))
case RawBodyType.FileBody =>
serverOptions
.createFile(request)
.flatMap(file => body.dataBytes.runWith(FileIO.toPath(file.toPath)).map(_ => FileRange(file)).map(f => RawValue(f, Seq(f))))
.flatMap(file =>
body.dataBytes
.runWith(FileIO.toPath(file.toPath))
.recoverWith {
// We need to dig out EntityStreamSizeException from an external wrapper applied by FileIO sink
case e: Exception if e.getCause().isInstanceOf[EntityStreamSizeException] =>
Future.failed(e.getCause())
}
.map(_ => FileRange(file))
.map(f => RawValue(f, Seq(f)))
)
case RawBodyType.InputStreamRangeBody =>
implicitly[FromEntityUnmarshaller[Array[Byte]]]
.apply(body)
.map(b => RawValue(InputStreamRange(() => new ByteArrayInputStream(b))))
Future.successful(RawValue(InputStreamRange(() => body.dataBytes.runWith(StreamConverters.asInputStream()))))
case m: RawBodyType.MultipartBody =>
implicitly[FromEntityUnmarshaller[Multipart.FormData]].apply(body).flatMap { fd =>
fd.parts
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package sttp.tapir.server.akkahttp

import akka.http.scaladsl.model.EntityStreamSizeException
import sttp.capabilities.StreamMaxLengthExceededException
import sttp.monad.MonadError
import sttp.tapir.server.interceptor.exception.{ExceptionContext, ExceptionHandler, ExceptionInterceptor}
import sttp.tapir.server.model.ValuedEndpointOutput

import scala.concurrent.Future

/** Used by AkkaHttpServerInterpreter to catch specific scenarios related to exceeding max content length in requests:
* - EntityStreamSizeException thrown when InputBody is an Akka Stream, which exceeds max length limit during processing in serverLogic.
* - A wrapped EntityStreamSizeException failure, a variant of previous scenario where additional stage (like FileIO sink) wraps the
* underlying cause into another exception.
* - An InputStreamBody throws an IOException(EntityStreamSizeException) when reading the input stream fails due to exceeding max length
* limit in the underlying Akka Stream.
*
* All these scenarios mean basically the same, so we'll fail with our own StreamMaxLengthExceededException, a general mechanism intended
* to be handled by Tapir and result in a HTTP 413 Payload Too Large response.
*/
private[akkahttp] object AkkaStreamSizeExceptionInterceptor
extends ExceptionInterceptor[Future](new ExceptionHandler[Future] {
override def apply(ctx: ExceptionContext)(implicit monad: MonadError[Future]): Future[Option[ValuedEndpointOutput[_]]] = {
ctx.e match {
case ex: Exception if ex.getCause().isInstanceOf[EntityStreamSizeException] =>
monad.error(StreamMaxLengthExceededException(ex.getCause().asInstanceOf[EntityStreamSizeException].limit))
case EntityStreamSizeException(limit, _) =>
monad.error(StreamMaxLengthExceededException(limit))
case other =>
monad.error(other)
}
}
})
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ class AkkaHttpServerTest extends TestSuite with EitherValues {
stream.runWith(Sink.ignore).map(_ => ())

new AllServerTests(createServerTest, interpreter, backend).tests() ++
new ServerStreamingTests(createServerTest, maxLengthSupported = true).tests(AkkaStreams)(drainAkka) ++
new ServerStreamingTests(createServerTest).tests(AkkaStreams)(drainAkka) ++
new ServerWebSocketTests(createServerTest, AkkaStreams) {
override def functionToPipe[A, B](f: A => B): streams.Pipe[A, B] = Flow.fromFunction(f)
override def emptyPipe[A, B]: Flow[A, B, Any] = Flow.fromSinkAndSource(Sink.ignore, Source.empty)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ class ArmeriaCatsServerTest extends TestSuite {
def drainFs2(stream: Fs2Streams[IO]#BinaryStream): IO[Unit] =
stream.compile.drain.void

new AllServerTests(createServerTest, interpreter, backend, basic = false, options = false).tests() ++
new ServerBasicTests(createServerTest, interpreter, supportsUrlEncodedPathSegments = false).tests() ++
new ServerStreamingTests(createServerTest, maxLengthSupported = true).tests(Fs2Streams[IO])(drainFs2)
new AllServerTests(createServerTest, interpreter, backend, basic = false, options = false, maxContentLength = false).tests() ++
new ServerBasicTests(createServerTest, interpreter, supportsUrlEncodedPathSegments = false, maxContentLength = false).tests() ++
new ServerStreamingTests(createServerTest).tests(Fs2Streams[IO])(drainFs2)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ class ArmeriaFutureServerTest extends TestSuite {
val interpreter = new ArmeriaTestFutureServerInterpreter()
val createServerTest = new DefaultCreateServerTest(backend, interpreter)

new AllServerTests(createServerTest, interpreter, backend, basic = false, options = false).tests() ++
new ServerBasicTests(createServerTest, interpreter, supportsUrlEncodedPathSegments = false).tests() ++
new AllServerTests(createServerTest, interpreter, backend, basic = false, options = false, maxContentLength = false).tests() ++
new ServerBasicTests(createServerTest, interpreter, supportsUrlEncodedPathSegments = false, maxContentLength = false).tests() ++
new ServerStreamingTests(createServerTest, maxLengthSupported = false).tests(ArmeriaStreams)(_ => Future.unit)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ class ArmeriaZioServerTest extends TestSuite {
def drainZStream(zStream: ZioStreams.BinaryStream): Task[Unit] =
zStream.run(ZSink.drain)

new AllServerTests(createServerTest, interpreter, backend, basic = false, options = false).tests() ++
new ServerBasicTests(createServerTest, interpreter, supportsUrlEncodedPathSegments = false).tests() ++
new ServerStreamingTests(createServerTest, maxLengthSupported = true).tests(ZioStreams)(drainZStream)
new AllServerTests(createServerTest, interpreter, backend, basic = false, options = false, maxContentLength = false).tests() ++
new ServerBasicTests(createServerTest, interpreter, supportsUrlEncodedPathSegments = false, maxContentLength = false).tests() ++
new ServerStreamingTests(createServerTest).tests(ZioStreams)(drainZStream)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,15 @@ class FinatraServerCatsTests extends TestSuite {
val interpreter = new FinatraCatsTestServerInterpreter(dispatcher)
val createServerTest = new DefaultCreateServerTest(backend, interpreter)

new AllServerTests(createServerTest, interpreter, backend, staticContent = false, reject = false, metrics = false).tests() ++
new AllServerTests(
createServerTest,
interpreter,
backend,
staticContent = false,
reject = false,
metrics = false,
maxContentLength = false
).tests() ++
new ServerFilesTests(interpreter, backend, supportSettingContentLength = false).tests()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ class FinatraServerTest extends TestSuite {
val interpreter = new FinatraTestServerInterpreter()
val createServerTest = new DefaultCreateServerTest(backend, interpreter)

new AllServerTests(createServerTest, interpreter, backend, staticContent = false, reject = false).tests() ++
new AllServerTests(createServerTest, interpreter, backend, staticContent = false, reject = false, maxContentLength = false).tests() ++
new ServerFilesTests(interpreter, backend, supportSettingContentLength = false).tests()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,15 @@ import sttp.tapir.server.interpreter.{RawValue, RequestBody}
import sttp.tapir.{FileRange, InputStreamRange, RawBodyType, RawPart}

import java.io.ByteArrayInputStream
import org.http4s.Media
import org.http4s.Headers

private[http4s] class Http4sRequestBody[F[_]: Async](
serverOptions: Http4sServerOptions[F]
) extends RequestBody[F, Fs2Streams[F]] {
override val streams: Fs2Streams[F] = Fs2Streams[F]
override def toRaw[R](serverRequest: ServerRequest, bodyType: RawBodyType[R], maxBytes: Option[Long]): F[RawValue[R]] = {
toRawFromStream(serverRequest, toStream(serverRequest, maxBytes), bodyType, http4sRequest(serverRequest).charset)
toRawFromStream(serverRequest, toStream(serverRequest, maxBytes), bodyType, http4sRequest(serverRequest).charset, maxBytes)
}
override def toStream(serverRequest: ServerRequest, maxBytes: Option[Long]): streams.BinaryStream = {
val stream = http4sRequest(serverRequest).body
Expand All @@ -32,7 +34,8 @@ private[http4s] class Http4sRequestBody[F[_]: Async](
serverRequest: ServerRequest,
body: fs2.Stream[F, Byte],
bodyType: RawBodyType[R],
charset: Option[Charset]
charset: Option[Charset],
maxBytes: Option[Long]
): F[RawValue[R]] = {
def asChunk: F[Chunk[Byte]] = body.compile.to(Chunk)
def asByteArray: F[Array[Byte]] = body.compile.to(Chunk).map(_.toArray[Byte])
Expand All @@ -52,26 +55,39 @@ private[http4s] class Http4sRequestBody[F[_]: Async](
}
case m: RawBodyType.MultipartBody =>
// TODO: use MultipartDecoder.mixedMultipart once available?
implicitly[EntityDecoder[F, multipart.Multipart[F]]].decode(http4sRequest(serverRequest), strict = false).value.flatMap {
case Left(failure) => Sync[F].raiseError(failure)
case Right(mp) =>
val rawPartsF: Vector[F[RawPart]] = mp.parts
.flatMap(part => part.name.flatMap(name => m.partType(name)).map((part, _)).toList)
.map { case (part, codecMeta) => toRawPart(serverRequest, part, codecMeta).asInstanceOf[F[RawPart]] }
implicitly[EntityDecoder[F, multipart.Multipart[F]]]
.decode(limitedMedia(http4sRequest(serverRequest), maxBytes), strict = false)
.value
.flatMap {
case Left(failure) => Sync[F].raiseError(failure)
case Right(mp) =>
val rawPartsF: Vector[F[RawPart]] = mp.parts
.flatMap(part => part.name.flatMap(name => m.partType(name)).map((part, _)).toList)
.map { case (part, codecMeta) => toRawPart(serverRequest, part, codecMeta).asInstanceOf[F[RawPart]] }

val rawParts: F[RawValue[Vector[RawPart]]] = rawPartsF.sequence.map { parts =>
RawValue(parts, parts collect { case _ @Part(_, f: FileRange, _, _) => f })
}
val rawParts: F[RawValue[Vector[RawPart]]] = rawPartsF.sequence.map { parts =>
RawValue(parts, parts collect { case _ @Part(_, f: FileRange, _, _) => f })
}

rawParts.asInstanceOf[F[RawValue[R]]] // R is Vector[RawPart]
}
rawParts.asInstanceOf[F[RawValue[R]]] // R is Vector[RawPart]
}
}
}

private def limitedMedia(media: Media[F], maxBytes: Option[Long]): Media[F] = maxBytes
.map(limit =>
new Media[F] {
override def body: fs2.Stream[F, Byte] = Fs2Streams.limitBytes(media.body, limit)
override def headers: Headers = media.headers
override def covary[F2[x] >: F[x]]: Media[F2] = media.covary
}
)
.getOrElse(media)

private def toRawPart[R](serverRequest: ServerRequest, part: multipart.Part[F], partType: RawBodyType[R]): F[Part[R]] = {
val dispositionParams = part.headers.get[`Content-Disposition`].map(_.parameters).getOrElse(Map.empty)
val charset = part.headers.get[`Content-Type`].flatMap(_.charset)
toRawFromStream(serverRequest, part.body, partType, charset)
toRawFromStream(serverRequest, part.body, partType, charset, maxBytes = None)
.map(r =>
Part(
part.name.getOrElse(""),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,8 @@ class Http4sServerTest[R >: Fs2Streams[IO] with WebSockets] extends TestSuite wi
def drainFs2(stream: Fs2Streams[IO]#BinaryStream): IO[Unit] =
stream.compile.drain.void

new AllServerTests(createServerTest, interpreter, backend, maxContentLength = true).tests() ++
new ServerStreamingTests(createServerTest, maxLengthSupported = true).tests(Fs2Streams[IO])(drainFs2) ++
new AllServerTests(createServerTest, interpreter, backend).tests() ++
new ServerStreamingTests(createServerTest).tests(Fs2Streams[IO])(drainFs2) ++
new ServerWebSocketTests(createServerTest, Fs2Streams[IO]) {
override def functionToPipe[A, B](f: A => B): streams.Pipe[A, B] = in => in.map(f)
override def emptyPipe[A, B]: Pipe[IO, A, B] = _ => fs2.Stream.empty
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ class ZHttp4sServerTest extends TestSuite with OptionValues {
def drainZStream(zStream: ZioStreams.BinaryStream): Task[Unit] =
zStream.run(ZSink.drain)

new AllServerTests(createServerTest, interpreter, backend, maxContentLength = true).tests() ++
new ServerStreamingTests(createServerTest, maxLengthSupported = true).tests(ZioStreams)(drainZStream) ++
new AllServerTests(createServerTest, interpreter, backend).tests() ++
new ServerStreamingTests(createServerTest).tests(ZioStreams)(drainZStream) ++
new ServerWebSocketTests(createServerTest, ZioStreams) {
override def functionToPipe[A, B](f: A => B): streams.Pipe[A, B] = in => in.map(f)
override def emptyPipe[A, B]: streams.Pipe[A, B] = _ => ZStream.empty
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ private[jdkhttp] class JdkHttpRequestBody(createFile: ServerRequest => TapirFile
val file = createFile(serverRequest)
Files.copy(asInputStream, file.toPath, StandardCopyOption.REPLACE_EXISTING)
RawValue(FileRange(file), Seq(FileRange(file)))
case m: RawBodyType.MultipartBody => RawValue.fromParts(multiPartRequestToRawBody(serverRequest, m))
case m: RawBodyType.MultipartBody => RawValue.fromParts(multiPartRequestToRawBody(serverRequest, asInputStream, m))
}
}

Expand All @@ -57,11 +57,11 @@ private[jdkhttp] class JdkHttpRequestBody(createFile: ServerRequest => TapirFile
.getOrElse(throw new IllegalArgumentException("Unable to extract multipart boundary from multipart request"))
}

private def multiPartRequestToRawBody(request: ServerRequest, m: RawBodyType.MultipartBody): Seq[RawPart] = {
private def multiPartRequestToRawBody(request: ServerRequest, requestBody: InputStream, m: RawBodyType.MultipartBody): Seq[RawPart] = {
val httpExchange = jdkHttpRequest(request)
val boundary = extractBoundary(httpExchange)

parseMultipartBody(httpExchange.getRequestBody, boundary, multipartFileThresholdBytes).flatMap(parsedPart =>
parseMultipartBody(requestBody, boundary, multipartFileThresholdBytes).flatMap(parsedPart =>
parsedPart.getName.flatMap(name =>
m.partType(name)
.map(partType => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class JdkHttpServerTest extends TestSuite with EitherValues {
val interpreter = new JdkHttpTestServerInterpreter()
val createServerTest = new DefaultCreateServerTest(backend, interpreter)

new ServerBasicTests(createServerTest, interpreter, invulnerableToUnsanitizedHeaders = false, maxContentLength = true).tests() ++
new ServerBasicTests(createServerTest, interpreter, invulnerableToUnsanitizedHeaders = false).tests() ++
new AllServerTests(createServerTest, interpreter, backend, basic = false).tests()
})
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,10 @@ class NettyCatsServerTest extends TestSuite with EitherValues {
createServerTest,
interpreter,
backend,
multipart = false,
maxContentLength = true
multipart = false
)
.tests() ++
new ServerStreamingTests(createServerTest, maxLengthSupported = true).tests(Fs2Streams[IO])(drainFs2) ++
new ServerStreamingTests(createServerTest).tests(Fs2Streams[IO])(drainFs2) ++
new ServerCancellationTests(createServerTest)(m, IO.asyncForIO).tests() ++
new NettyFs2StreamingCancellationTest(createServerTest).tests() ++
new ServerGracefulShutdownTests(createServerTest, ioSleeper).tests()
Expand Down
Loading

0 comments on commit 716ce3a

Please sign in to comment.