Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MaxContentLength for http4s, jdkhttp, play #3374

Merged
merged 5 commits into from
Dec 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion doc/endpoint/security.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ Optional and multiple authentication inputs have some additional rules as to how
## Limiting request body length

*Supported backends*:
Feature enabled only for Netty-based servers. More backends will be added in the near future.
This feature is available for backends based on http4s, jdkhttp, Netty, and Play. More backends will be added in the near future.

Individual endpoints can be annotated with content length limit:

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,7 @@ private[http4s] class Http4sRequestBody[F[_]: Async](
) extends RequestBody[F, Fs2Streams[F]] {
override val streams: Fs2Streams[F] = Fs2Streams[F]
override def toRaw[R](serverRequest: ServerRequest, bodyType: RawBodyType[R], maxBytes: Option[Long]): F[RawValue[R]] = {
val r = http4sRequest(serverRequest)
toRawFromStream(serverRequest, r.body, bodyType, r.charset)
toRawFromStream(serverRequest, toStream(serverRequest, maxBytes), bodyType, http4sRequest(serverRequest).charset)
}
override def toStream(serverRequest: ServerRequest, maxBytes: Option[Long]): streams.BinaryStream = {
val stream = http4sRequest(serverRequest).body
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ class Http4sServerTest[R >: Fs2Streams[IO] with WebSockets] extends TestSuite wi
def drainFs2(stream: Fs2Streams[IO]#BinaryStream): IO[Unit] =
stream.compile.drain.void

new AllServerTests(createServerTest, interpreter, backend).tests() ++
new AllServerTests(createServerTest, interpreter, backend, maxContentLength = true).tests() ++
new ServerStreamingTests(createServerTest, maxLengthSupported = true).tests(Fs2Streams[IO])(drainFs2) ++
new ServerWebSocketTests(createServerTest, Fs2Streams[IO]) {
override def functionToPipe[A, B](f: A => B): streams.Pipe[A, B] = in => in.map(f)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ class ZHttp4sServerTest extends TestSuite with OptionValues {
def drainZStream(zStream: ZioStreams.BinaryStream): Task[Unit] =
zStream.run(ZSink.drain)

new AllServerTests(createServerTest, interpreter, backend).tests() ++
new ServerStreamingTests(createServerTest, maxLengthSupported = true).tests(ZioStreams)(drainZStream) ++
new AllServerTests(createServerTest, interpreter, backend, maxContentLength = true).tests() ++
new ServerStreamingTests(createServerTest, maxLengthSupported = true).tests(ZioStreams)(drainZStream) ++
new ServerWebSocketTests(createServerTest, ZioStreams) {
override def functionToPipe[A, B](f: A => B): streams.Pipe[A, B] = in => in.map(f)
override def emptyPipe[A, B]: streams.Pipe[A, B] = _ => ZStream.empty
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@ private[jdkhttp] class JdkHttpRequestBody(createFile: ServerRequest => TapirFile

override def toRaw[RAW](serverRequest: ServerRequest, bodyType: RawBodyType[RAW], maxBytes: Option[Long]): RawValue[RAW] = {
val request = jdkHttpRequest(serverRequest)
toRaw(serverRequest, bodyType, request.getRequestBody)
toRaw(serverRequest, bodyType, request.getRequestBody, maxBytes)
}

private def toRaw[RAW](serverRequest: ServerRequest, bodyType: RawBodyType[RAW], body: InputStream): RawValue[RAW] = {
def asInputStream: InputStream = body
private def toRaw[RAW](serverRequest: ServerRequest, bodyType: RawBodyType[RAW], body: InputStream, maxBytes: Option[Long]): RawValue[RAW] = {
def asInputStream: InputStream = maxBytes.map(limit => new FailingLimitedInputStream(body, limit)).getOrElse(body)
def asByteArray: Array[Byte] = asInputStream.readAllBytes()

bodyType match {
Expand Down Expand Up @@ -65,7 +65,7 @@ private[jdkhttp] class JdkHttpRequestBody(createFile: ServerRequest => TapirFile
parsedPart.getName.flatMap(name =>
m.partType(name)
.map(partType => {
val bodyRawValue = toRaw(request, partType, parsedPart.getBody)
val bodyRawValue = toRaw(request, partType, parsedPart.getBody, maxBytes = None)
Part(
name,
bodyRawValue.value,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,12 +133,3 @@ private class ByteBufferBackedInputStream(buf: ByteBuffer) extends InputStream {
}
}

private class LimitedInputStream(delegate: InputStream, var limit: Long) extends InputStream {
override def read(): Int = {
if (limit == 0L) -1
else {
limit -= 1
delegate.read()
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
package sttp.tapir.server.jdkhttp.internal

import sttp.capabilities.StreamMaxLengthExceededException
import java.io.FilterInputStream
import java.io.InputStream
import java.io.IOException

class FailingLimitedInputStream(in: InputStream, limit: Long) extends LimitedInputStream(in, limit) {
override def onLimit: Int = {
throw new StreamMaxLengthExceededException(limit)
}
}

/** Based on Guava's https://github.com/google/guava/blob/master/guava/src/com/google/common/io/ByteStreams.java
*/
class LimitedInputStream(in: InputStream, limit: Long) extends FilterInputStream(in) {
protected var left: Long = limit
private var mark: Long = -1L

override def available(): Int = Math.min(in.available(), left.toInt)

override def mark(readLimit: Int): Unit = this.synchronized {
in.mark(readLimit)
mark = left
}

override def read(): Int = this.synchronized {
if (left == 0) {
onLimit
} else {
val result = in.read()
if (result != -1) {
left -= 1
}
result
}
}

override def read(b: Array[Byte], off: Int, len: Int): Int = this.synchronized {
if (left == 0) {
// Temporarily perform a read to check if more bytes are available
val checkRead = in.read()
if (checkRead == -1) {
-1 // No more bytes available in the stream
} else {
onLimit
}
} else {
val adjustedLen = Math.min(len, left.toInt)
val result = in.read(b, off, adjustedLen)
if (result != -1) {
left -= result
}
result
}
}

override def reset(): Unit = this.synchronized {
if (!in.markSupported) {
throw new IOException("Mark not supported")
}
if (mark == -1) {
throw new IOException("Mark not set")
}

in.reset()
left = mark
}

override def skip(n: Long): Long = this.synchronized {
val toSkip = Math.min(n, left)
val skipped = in.skip(toSkip)
left -= skipped
skipped
}

protected def onLimit: Int = -1
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class JdkHttpServerTest extends TestSuite with EitherValues {
val interpreter = new JdkHttpTestServerInterpreter()
val createServerTest = new DefaultCreateServerTest(backend, interpreter)

new ServerBasicTests(createServerTest, interpreter, invulnerableToUnsanitizedHeaders = false).tests() ++
new ServerBasicTests(createServerTest, interpreter, invulnerableToUnsanitizedHeaders = false, maxContentLength = true).tests() ++
new AllServerTests(createServerTest, interpreter, backend, basic = false).tests()
})
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,13 @@ private[play] class PlayRequestBody(serverOptions: PlayServerOptions)(implicit
) extends RequestBody[Future, PekkoStreams] {

override val streams: PekkoStreams = PekkoStreams
val parsers = serverOptions.playBodyParsers

override def toRaw[R](serverRequest: ServerRequest, bodyType: RawBodyType[R], maxBytes: Option[Long]): Future[RawValue[R]] = {
import mat.executionContext
val request = playRequest(serverRequest)
val charset = request.charset.map(Charset.forName)
toRaw(request, bodyType, charset, () => request.body, None)
toRaw(request, bodyType, charset, () => request.body, None, maxBytes)
}

override def toStream(serverRequest: ServerRequest, maxBytes: Option[Long]): streams.BinaryStream = {
Expand All @@ -40,17 +41,23 @@ private[play] class PlayRequestBody(serverOptions: PlayServerOptions)(implicit
bodyType: RawBodyType[R],
charset: Option[Charset],
body: () => Source[ByteString, Any],
bodyAsFile: Option[File]
bodyAsFile: Option[File],
maxBytes: Option[Long]
)(implicit
mat: Materializer,
ec: ExecutionContext
): Future[RawValue[R]] = {
// playBodyParsers is used, so that the maxLength limits from Play configuration are applied
def bodyAsByteString(): Future[ByteString] = {
serverOptions.playBodyParsers.byteString.apply(request).run(body()).flatMap {
case Left(result) => Future.failed(new PlayBodyParserException(result))
case Right(value) => Future.successful(value)
}
maxBytes
.map(parsers.byteString(_))
.getOrElse(parsers.byteString)
.apply(request)
.run(body())
.flatMap {
case Left(result) => Future.failed(new PlayBodyParserException(result))
case Right(value) => Future.successful(value)
}
}
bodyType match {
case RawBodyType.StringBody(defaultCharset) =>
Expand All @@ -67,10 +74,15 @@ private[play] class PlayRequestBody(serverOptions: PlayServerOptions)(implicit
Future.successful(RawValue(tapirFile, Seq(tapirFile)))
case None =>
val file = FileRange(serverOptions.temporaryFileCreator.create().toFile)
serverOptions.playBodyParsers.file(file.file).apply(request).run(body()).flatMap {
case Left(result) => Future.failed(new PlayBodyParserException(result))
case Right(_) => Future.successful(RawValue(file, Seq(file)))
}
maxBytes
.map(parsers.file(file.file, _))
.getOrElse(parsers.file(file.file))
.apply(request)
.run(body())
.flatMap {
case Left(result) => Future.failed(new PlayBodyParserException(result))
case Right(_) => Future.successful(RawValue(file, Seq(file)))
}
}
case m: RawBodyType.MultipartBody => multiPartRequestToRawBody(request, m, body)
}
Expand Down Expand Up @@ -100,7 +112,8 @@ private[play] class PlayRequestBody(serverOptions: PlayServerOptions)(implicit
partType,
charset(partType),
() => Source(data),
None
bodyAsFile = None,
maxBytes = None
).map(body => Some(Part(key, body.value)))
}
}.toSeq
Expand All @@ -113,7 +126,8 @@ private[play] class PlayRequestBody(serverOptions: PlayServerOptions)(implicit
partType,
charset(partType),
() => FileIO.fromPath(f.ref.path),
Some(f.ref.toFile)
Some(f.ref.toFile),
maxBytes = None
).map(body =>
Some(
Part(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package sttp.tapir.server.play

import org.apache.pekko.actor.ActorSystem
import enumeratum._
import org.apache.pekko.stream.scaladsl.{Flow, Sink, Source}
import cats.data.NonEmptyList
import cats.effect.{IO, Resource}
Expand All @@ -17,8 +16,6 @@ import sttp.tapir.server.tests._
import sttp.tapir.tests.{Test, TestSuite}

import scala.concurrent.Future
import sttp.tapir.codec.enumeratum.TapirCodecEnumeratum
import sttp.tapir.server.interceptor.decodefailure.DefaultDecodeFailureHandler

class PlayServerTest extends TestSuite {

Expand Down Expand Up @@ -112,10 +109,18 @@ class PlayServerTest extends TestSuite {
interpreter,
multipleValueHeaderSupport = false,
inputStreamSupport = false,
invulnerableToUnsanitizedHeaders = false
invulnerableToUnsanitizedHeaders = false,
maxContentLength = true
).tests() ++
new ServerMultipartTests(createServerTest, partOtherHeaderSupport = false).tests() ++
new AllServerTests(createServerTest, interpreter, backend, basic = false, multipart = false, options = false).tests() ++
new AllServerTests(
createServerTest,
interpreter,
backend,
basic = false,
multipart = false,
options = false
).tests() ++
new ServerStreamingTests(createServerTest, maxLengthSupported = true).tests(PekkoStreams)(drainPekko) ++
new PlayServerWithContextTest(backend).tests() ++
new ServerWebSocketTests(createServerTest, PekkoStreams) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,13 @@ private[play] class PlayRequestBody(serverOptions: PlayServerOptions)(implicit
) extends RequestBody[Future, AkkaStreams] {

override val streams: AkkaStreams = AkkaStreams
val parsers = serverOptions.playBodyParsers

override def toRaw[R](serverRequest: ServerRequest, bodyType: RawBodyType[R], maxBytes: Option[Long]): Future[RawValue[R]] = {
import mat.executionContext
val request = playRequest(serverRequest)
val charset = request.charset.map(Charset.forName)
toRaw(request, bodyType, charset, () => request.body, None)
toRaw(request, bodyType, charset, () => request.body, None, maxBytes)
}

override def toStream(serverRequest: ServerRequest, maxBytes: Option[Long]): streams.BinaryStream = {
Expand All @@ -40,17 +41,23 @@ private[play] class PlayRequestBody(serverOptions: PlayServerOptions)(implicit
bodyType: RawBodyType[R],
charset: Option[Charset],
body: () => Source[ByteString, Any],
bodyAsFile: Option[File]
bodyAsFile: Option[File],
maxBytes: Option[Long]
)(implicit
mat: Materializer,
ec: ExecutionContext
): Future[RawValue[R]] = {
// playBodyParsers is used, so that the maxLength limits from Play configuration are applied
def bodyAsByteString(): Future[ByteString] = {
serverOptions.playBodyParsers.byteString.apply(request).run(body()).flatMap {
case Left(result) => Future.failed(new PlayBodyParserException(result))
case Right(value) => Future.successful(value)
}
maxBytes
.map(parsers.byteString(_))
.getOrElse(parsers.byteString)
.apply(request)
.run(body())
.flatMap {
case Left(result) => Future.failed(new PlayBodyParserException(result))
case Right(value) => Future.successful(value)
}
}
bodyType match {
case RawBodyType.StringBody(defaultCharset) =>
Expand All @@ -67,10 +74,15 @@ private[play] class PlayRequestBody(serverOptions: PlayServerOptions)(implicit
Future.successful(RawValue(tapirFile, Seq(tapirFile)))
case None =>
val file = FileRange(serverOptions.temporaryFileCreator.create().toFile)
serverOptions.playBodyParsers.file(file.file).apply(request).run(body()).flatMap {
case Left(result) => Future.failed(new PlayBodyParserException(result))
case Right(_) => Future.successful(RawValue(file, Seq(file)))
}
maxBytes
.map(parsers.file(file.file, _))
.getOrElse(parsers.file(file.file))
.apply(request)
.run(body())
.flatMap {
case Left(result) => Future.failed(new PlayBodyParserException(result))
case Right(_) => Future.successful(RawValue(file, Seq(file)))
}
}
case m: RawBodyType.MultipartBody => multiPartRequestToRawBody(request, m, body)
}
Expand Down Expand Up @@ -100,7 +112,8 @@ private[play] class PlayRequestBody(serverOptions: PlayServerOptions)(implicit
partType,
charset(partType),
() => Source(data),
None
bodyAsFile = None,
maxBytes = None
).map(body => Some(Part(key, body.value)))
}
}.toSeq
Expand All @@ -113,7 +126,8 @@ private[play] class PlayRequestBody(serverOptions: PlayServerOptions)(implicit
partType,
charset(partType),
() => FileIO.fromPath(f.ref.path),
Some(f.ref.toFile)
Some(f.ref.toFile),
maxBytes = None
).map(body =>
Some(
Part(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,8 @@ class PlayServerTest extends TestSuite {
interpreter,
multipleValueHeaderSupport = false,
inputStreamSupport = false,
invulnerableToUnsanitizedHeaders = false
invulnerableToUnsanitizedHeaders = false,
maxContentLength = true
).tests() ++
new ServerMultipartTests(createServerTest, partOtherHeaderSupport = false).tests() ++
new AllServerTests(createServerTest, interpreter, backend, basic = false, multipart = false, options = false).tests() ++
Expand Down
Loading