diff --git a/doc/endpoint/security.md b/doc/endpoint/security.md index 859676c391..a92af08cff 100644 --- a/doc/endpoint/security.md +++ b/doc/endpoint/security.md @@ -51,8 +51,8 @@ Optional and multiple authentication inputs have some additional rules as to how ## Limiting request body length -*Supported backends*: -This feature is available for backends based on http4s, jdkhttp, Netty, and Play. More backends will be added in the near future. +*Unsupported backends*: +This feature is available for all server backends *except*: `akka-grpc`, `Armeria`, `Finatra`, `Helidon Nima`, `pekko-grpc`, `zio-http`. Individual endpoints can be annotated with content length limit: diff --git a/server/akka-http-server/src/main/scala/sttp/tapir/server/akkahttp/AkkaHttpServerInterpreter.scala b/server/akka-http-server/src/main/scala/sttp/tapir/server/akkahttp/AkkaHttpServerInterpreter.scala index da3f98febe..7204c631e6 100644 --- a/server/akka-http-server/src/main/scala/sttp/tapir/server/akkahttp/AkkaHttpServerInterpreter.scala +++ b/server/akka-http-server/src/main/scala/sttp/tapir/server/akkahttp/AkkaHttpServerInterpreter.scala @@ -44,7 +44,10 @@ trait AkkaHttpServerInterpreter { toResponseBody: (Materializer, ExecutionContext) => ToResponseBody[AkkaResponseBody, AkkaStreams] )(ses: List[ServerEndpoint[AkkaStreams with WebSockets, Future]]): Route = { val filterServerEndpoints = FilterServerEndpoints(ses) - val interceptors = RejectInterceptor.disableWhenSingleEndpoint(akkaHttpServerOptions.interceptors, ses) + val interceptors = RejectInterceptor.disableWhenSingleEndpoint( + akkaHttpServerOptions.appendInterceptor(AkkaStreamSizeExceptionInterceptor).interceptors, + ses + ) extractExecutionContext { implicit ec => extractMaterializer { implicit mat => diff --git a/server/akka-http-server/src/main/scala/sttp/tapir/server/akkahttp/AkkaRequestBody.scala b/server/akka-http-server/src/main/scala/sttp/tapir/server/akkahttp/AkkaRequestBody.scala index 19654402f7..91e654e5bc 100644 --- a/server/akka-http-server/src/main/scala/sttp/tapir/server/akkahttp/AkkaRequestBody.scala +++ b/server/akka-http-server/src/main/scala/sttp/tapir/server/akkahttp/AkkaRequestBody.scala @@ -1,10 +1,10 @@ package sttp.tapir.server.akkahttp -import akka.http.scaladsl.model.{HttpEntity, Multipart} +import akka.http.scaladsl.model._ import akka.http.scaladsl.server.RequestContext import akka.http.scaladsl.unmarshalling.FromEntityUnmarshaller import akka.stream.Materializer -import akka.stream.scaladsl.{FileIO, Sink} +import akka.stream.scaladsl._ import akka.util.ByteString import sttp.capabilities.akka.AkkaStreams import sttp.model.{Header, Part} @@ -12,8 +12,6 @@ import sttp.tapir.model.ServerRequest import sttp.tapir.server.interpreter.{RawValue, RequestBody} import sttp.tapir.{FileRange, RawBodyType, RawPart, InputStreamRange} -import java.io.{ByteArrayInputStream, InputStream} - import scala.concurrent.{ExecutionContext, Future} private[akkahttp] class AkkaRequestBody(serverOptions: AkkaHttpServerOptions)(implicit @@ -22,29 +20,44 @@ private[akkahttp] class AkkaRequestBody(serverOptions: AkkaHttpServerOptions)(im ) extends RequestBody[Future, AkkaStreams] { override val streams: AkkaStreams = AkkaStreams override def toRaw[R](request: ServerRequest, bodyType: RawBodyType[R], maxBytes: Option[Long]): Future[RawValue[R]] = - toRawFromEntity(request, akkeRequestEntity(request), bodyType) + toRawFromEntity(request, requestEntity(request, maxBytes), bodyType) + override def toStream(request: ServerRequest, maxBytes: Option[Long]): streams.BinaryStream = { - val stream = akkeRequestEntity(request).dataBytes - maxBytes.map(AkkaStreams.limitBytes(stream, _)).getOrElse(stream) + requestEntity(request, maxBytes).dataBytes } - private def akkeRequestEntity(request: ServerRequest) = request.underlying.asInstanceOf[RequestContext].request.entity + private def requestEntity(request: ServerRequest, maxBytes: Option[Long]): RequestEntity = { + val entity = request.underlying.asInstanceOf[RequestContext].request.entity + maxBytes.map(entity.withSizeLimit).getOrElse(entity) + } - private def toRawFromEntity[R](request: ServerRequest, body: HttpEntity, bodyType: RawBodyType[R]): Future[RawValue[R]] = { + private def toRawFromEntity[R]( + request: ServerRequest, + body: HttpEntity, + bodyType: RawBodyType[R] + ): Future[RawValue[R]] = { bodyType match { case RawBodyType.StringBody(_) => implicitly[FromEntityUnmarshaller[String]].apply(body).map(RawValue(_)) case RawBodyType.ByteArrayBody => implicitly[FromEntityUnmarshaller[Array[Byte]]].apply(body).map(RawValue(_)) case RawBodyType.ByteBufferBody => implicitly[FromEntityUnmarshaller[ByteString]].apply(body).map(b => RawValue(b.asByteBuffer)) case RawBodyType.InputStreamBody => - implicitly[FromEntityUnmarshaller[Array[Byte]]].apply(body).map(b => RawValue(new ByteArrayInputStream(b))) + Future.successful(RawValue(body.dataBytes.runWith(StreamConverters.asInputStream()))) case RawBodyType.FileBody => serverOptions .createFile(request) - .flatMap(file => body.dataBytes.runWith(FileIO.toPath(file.toPath)).map(_ => FileRange(file)).map(f => RawValue(f, Seq(f)))) + .flatMap(file => + body.dataBytes + .runWith(FileIO.toPath(file.toPath)) + .recoverWith { + // We need to dig out EntityStreamSizeException from an external wrapper applied by FileIO sink + case e: Exception if e.getCause().isInstanceOf[EntityStreamSizeException] => + Future.failed(e.getCause()) + } + .map(_ => FileRange(file)) + .map(f => RawValue(f, Seq(f))) + ) case RawBodyType.InputStreamRangeBody => - implicitly[FromEntityUnmarshaller[Array[Byte]]] - .apply(body) - .map(b => RawValue(InputStreamRange(() => new ByteArrayInputStream(b)))) + Future.successful(RawValue(InputStreamRange(() => body.dataBytes.runWith(StreamConverters.asInputStream())))) case m: RawBodyType.MultipartBody => implicitly[FromEntityUnmarshaller[Multipart.FormData]].apply(body).flatMap { fd => fd.parts diff --git a/server/akka-http-server/src/main/scala/sttp/tapir/server/akkahttp/AkkaStreamSizeExceptionInterceptor.scala b/server/akka-http-server/src/main/scala/sttp/tapir/server/akkahttp/AkkaStreamSizeExceptionInterceptor.scala new file mode 100644 index 0000000000..ba971f7bd2 --- /dev/null +++ b/server/akka-http-server/src/main/scala/sttp/tapir/server/akkahttp/AkkaStreamSizeExceptionInterceptor.scala @@ -0,0 +1,33 @@ +package sttp.tapir.server.akkahttp + +import akka.http.scaladsl.model.EntityStreamSizeException +import sttp.capabilities.StreamMaxLengthExceededException +import sttp.monad.MonadError +import sttp.tapir.server.interceptor.exception.{ExceptionContext, ExceptionHandler, ExceptionInterceptor} +import sttp.tapir.server.model.ValuedEndpointOutput + +import scala.concurrent.Future + +/** Used by AkkaHttpServerInterpreter to catch specific scenarios related to exceeding max content length in requests: + * - EntityStreamSizeException thrown when InputBody is an Akka Stream, which exceeds max length limit during processing in serverLogic. + * - A wrapped EntityStreamSizeException failure, a variant of previous scenario where additional stage (like FileIO sink) wraps the + * underlying cause into another exception. + * - An InputStreamBody throws an IOException(EntityStreamSizeException) when reading the input stream fails due to exceeding max length + * limit in the underlying Akka Stream. + * + * All these scenarios mean basically the same, so we'll fail with our own StreamMaxLengthExceededException, a general mechanism intended + * to be handled by Tapir and result in a HTTP 413 Payload Too Large response. + */ +private[akkahttp] object AkkaStreamSizeExceptionInterceptor + extends ExceptionInterceptor[Future](new ExceptionHandler[Future] { + override def apply(ctx: ExceptionContext)(implicit monad: MonadError[Future]): Future[Option[ValuedEndpointOutput[_]]] = { + ctx.e match { + case ex: Exception if ex.getCause().isInstanceOf[EntityStreamSizeException] => + monad.error(StreamMaxLengthExceededException(ex.getCause().asInstanceOf[EntityStreamSizeException].limit)) + case EntityStreamSizeException(limit, _) => + monad.error(StreamMaxLengthExceededException(limit)) + case other => + monad.error(other) + } + } + }) diff --git a/server/akka-http-server/src/test/scala/sttp/tapir/server/akkahttp/AkkaHttpServerTest.scala b/server/akka-http-server/src/test/scala/sttp/tapir/server/akkahttp/AkkaHttpServerTest.scala index 209602d4df..3cb5e63b73 100644 --- a/server/akka-http-server/src/test/scala/sttp/tapir/server/akkahttp/AkkaHttpServerTest.scala +++ b/server/akka-http-server/src/test/scala/sttp/tapir/server/akkahttp/AkkaHttpServerTest.scala @@ -156,7 +156,7 @@ class AkkaHttpServerTest extends TestSuite with EitherValues { stream.runWith(Sink.ignore).map(_ => ()) new AllServerTests(createServerTest, interpreter, backend).tests() ++ - new ServerStreamingTests(createServerTest, maxLengthSupported = true).tests(AkkaStreams)(drainAkka) ++ + new ServerStreamingTests(createServerTest).tests(AkkaStreams)(drainAkka) ++ new ServerWebSocketTests(createServerTest, AkkaStreams) { override def functionToPipe[A, B](f: A => B): streams.Pipe[A, B] = Flow.fromFunction(f) override def emptyPipe[A, B]: Flow[A, B, Any] = Flow.fromSinkAndSource(Sink.ignore, Source.empty) diff --git a/server/armeria-server/cats/src/test/scala/sttp/tapir/server/armeria/cats/ArmeriaCatsServerTest.scala b/server/armeria-server/cats/src/test/scala/sttp/tapir/server/armeria/cats/ArmeriaCatsServerTest.scala index cde75c2a91..4181406515 100644 --- a/server/armeria-server/cats/src/test/scala/sttp/tapir/server/armeria/cats/ArmeriaCatsServerTest.scala +++ b/server/armeria-server/cats/src/test/scala/sttp/tapir/server/armeria/cats/ArmeriaCatsServerTest.scala @@ -16,8 +16,8 @@ class ArmeriaCatsServerTest extends TestSuite { def drainFs2(stream: Fs2Streams[IO]#BinaryStream): IO[Unit] = stream.compile.drain.void - new AllServerTests(createServerTest, interpreter, backend, basic = false, options = false).tests() ++ - new ServerBasicTests(createServerTest, interpreter, supportsUrlEncodedPathSegments = false).tests() ++ - new ServerStreamingTests(createServerTest, maxLengthSupported = true).tests(Fs2Streams[IO])(drainFs2) + new AllServerTests(createServerTest, interpreter, backend, basic = false, options = false, maxContentLength = false).tests() ++ + new ServerBasicTests(createServerTest, interpreter, supportsUrlEncodedPathSegments = false, maxContentLength = false).tests() ++ + new ServerStreamingTests(createServerTest).tests(Fs2Streams[IO])(drainFs2) } } diff --git a/server/armeria-server/src/test/scala/sttp/tapir/server/armeria/ArmeriaFutureServerTest.scala b/server/armeria-server/src/test/scala/sttp/tapir/server/armeria/ArmeriaFutureServerTest.scala index 22f546f4da..3c7d99de62 100644 --- a/server/armeria-server/src/test/scala/sttp/tapir/server/armeria/ArmeriaFutureServerTest.scala +++ b/server/armeria-server/src/test/scala/sttp/tapir/server/armeria/ArmeriaFutureServerTest.scala @@ -15,8 +15,8 @@ class ArmeriaFutureServerTest extends TestSuite { val interpreter = new ArmeriaTestFutureServerInterpreter() val createServerTest = new DefaultCreateServerTest(backend, interpreter) - new AllServerTests(createServerTest, interpreter, backend, basic = false, options = false).tests() ++ - new ServerBasicTests(createServerTest, interpreter, supportsUrlEncodedPathSegments = false).tests() ++ + new AllServerTests(createServerTest, interpreter, backend, basic = false, options = false, maxContentLength = false).tests() ++ + new ServerBasicTests(createServerTest, interpreter, supportsUrlEncodedPathSegments = false, maxContentLength = false).tests() ++ new ServerStreamingTests(createServerTest, maxLengthSupported = false).tests(ArmeriaStreams)(_ => Future.unit) } } diff --git a/server/armeria-server/zio/src/test/scala/sttp/tapir/server/armeria/zio/ArmeriaZioServerTest.scala b/server/armeria-server/zio/src/test/scala/sttp/tapir/server/armeria/zio/ArmeriaZioServerTest.scala index 17df98f35a..1b00ad09ef 100644 --- a/server/armeria-server/zio/src/test/scala/sttp/tapir/server/armeria/zio/ArmeriaZioServerTest.scala +++ b/server/armeria-server/zio/src/test/scala/sttp/tapir/server/armeria/zio/ArmeriaZioServerTest.scala @@ -20,8 +20,8 @@ class ArmeriaZioServerTest extends TestSuite { def drainZStream(zStream: ZioStreams.BinaryStream): Task[Unit] = zStream.run(ZSink.drain) - new AllServerTests(createServerTest, interpreter, backend, basic = false, options = false).tests() ++ - new ServerBasicTests(createServerTest, interpreter, supportsUrlEncodedPathSegments = false).tests() ++ - new ServerStreamingTests(createServerTest, maxLengthSupported = true).tests(ZioStreams)(drainZStream) + new AllServerTests(createServerTest, interpreter, backend, basic = false, options = false, maxContentLength = false).tests() ++ + new ServerBasicTests(createServerTest, interpreter, supportsUrlEncodedPathSegments = false, maxContentLength = false).tests() ++ + new ServerStreamingTests(createServerTest).tests(ZioStreams)(drainZStream) } } diff --git a/server/finatra-server/cats/src/test/scala/sttp/tapir/server/finatra/cats/FinatraServerCatsTests.scala b/server/finatra-server/cats/src/test/scala/sttp/tapir/server/finatra/cats/FinatraServerCatsTests.scala index 2a2e6b3d8e..2e365acc1b 100644 --- a/server/finatra-server/cats/src/test/scala/sttp/tapir/server/finatra/cats/FinatraServerCatsTests.scala +++ b/server/finatra-server/cats/src/test/scala/sttp/tapir/server/finatra/cats/FinatraServerCatsTests.scala @@ -12,7 +12,15 @@ class FinatraServerCatsTests extends TestSuite { val interpreter = new FinatraCatsTestServerInterpreter(dispatcher) val createServerTest = new DefaultCreateServerTest(backend, interpreter) - new AllServerTests(createServerTest, interpreter, backend, staticContent = false, reject = false, metrics = false).tests() ++ + new AllServerTests( + createServerTest, + interpreter, + backend, + staticContent = false, + reject = false, + metrics = false, + maxContentLength = false + ).tests() ++ new ServerFilesTests(interpreter, backend, supportSettingContentLength = false).tests() } } diff --git a/server/finatra-server/src/test/scala/sttp/tapir/server/finatra/FinatraServerTest.scala b/server/finatra-server/src/test/scala/sttp/tapir/server/finatra/FinatraServerTest.scala index 6b1d4d7d7f..e53603fdbf 100644 --- a/server/finatra-server/src/test/scala/sttp/tapir/server/finatra/FinatraServerTest.scala +++ b/server/finatra-server/src/test/scala/sttp/tapir/server/finatra/FinatraServerTest.scala @@ -11,7 +11,7 @@ class FinatraServerTest extends TestSuite { val interpreter = new FinatraTestServerInterpreter() val createServerTest = new DefaultCreateServerTest(backend, interpreter) - new AllServerTests(createServerTest, interpreter, backend, staticContent = false, reject = false).tests() ++ + new AllServerTests(createServerTest, interpreter, backend, staticContent = false, reject = false, maxContentLength = false).tests() ++ new ServerFilesTests(interpreter, backend, supportSettingContentLength = false).tests() } } diff --git a/server/http4s-server/src/main/scala/sttp/tapir/server/http4s/Http4sRequestBody.scala b/server/http4s-server/src/main/scala/sttp/tapir/server/http4s/Http4sRequestBody.scala index 6ad163b5ee..187c046594 100644 --- a/server/http4s-server/src/main/scala/sttp/tapir/server/http4s/Http4sRequestBody.scala +++ b/server/http4s-server/src/main/scala/sttp/tapir/server/http4s/Http4sRequestBody.scala @@ -13,13 +13,15 @@ import sttp.tapir.server.interpreter.{RawValue, RequestBody} import sttp.tapir.{FileRange, InputStreamRange, RawBodyType, RawPart} import java.io.ByteArrayInputStream +import org.http4s.Media +import org.http4s.Headers private[http4s] class Http4sRequestBody[F[_]: Async]( serverOptions: Http4sServerOptions[F] ) extends RequestBody[F, Fs2Streams[F]] { override val streams: Fs2Streams[F] = Fs2Streams[F] override def toRaw[R](serverRequest: ServerRequest, bodyType: RawBodyType[R], maxBytes: Option[Long]): F[RawValue[R]] = { - toRawFromStream(serverRequest, toStream(serverRequest, maxBytes), bodyType, http4sRequest(serverRequest).charset) + toRawFromStream(serverRequest, toStream(serverRequest, maxBytes), bodyType, http4sRequest(serverRequest).charset, maxBytes) } override def toStream(serverRequest: ServerRequest, maxBytes: Option[Long]): streams.BinaryStream = { val stream = http4sRequest(serverRequest).body @@ -32,7 +34,8 @@ private[http4s] class Http4sRequestBody[F[_]: Async]( serverRequest: ServerRequest, body: fs2.Stream[F, Byte], bodyType: RawBodyType[R], - charset: Option[Charset] + charset: Option[Charset], + maxBytes: Option[Long] ): F[RawValue[R]] = { def asChunk: F[Chunk[Byte]] = body.compile.to(Chunk) def asByteArray: F[Array[Byte]] = body.compile.to(Chunk).map(_.toArray[Byte]) @@ -52,26 +55,39 @@ private[http4s] class Http4sRequestBody[F[_]: Async]( } case m: RawBodyType.MultipartBody => // TODO: use MultipartDecoder.mixedMultipart once available? - implicitly[EntityDecoder[F, multipart.Multipart[F]]].decode(http4sRequest(serverRequest), strict = false).value.flatMap { - case Left(failure) => Sync[F].raiseError(failure) - case Right(mp) => - val rawPartsF: Vector[F[RawPart]] = mp.parts - .flatMap(part => part.name.flatMap(name => m.partType(name)).map((part, _)).toList) - .map { case (part, codecMeta) => toRawPart(serverRequest, part, codecMeta).asInstanceOf[F[RawPart]] } + implicitly[EntityDecoder[F, multipart.Multipart[F]]] + .decode(limitedMedia(http4sRequest(serverRequest), maxBytes), strict = false) + .value + .flatMap { + case Left(failure) => Sync[F].raiseError(failure) + case Right(mp) => + val rawPartsF: Vector[F[RawPart]] = mp.parts + .flatMap(part => part.name.flatMap(name => m.partType(name)).map((part, _)).toList) + .map { case (part, codecMeta) => toRawPart(serverRequest, part, codecMeta).asInstanceOf[F[RawPart]] } - val rawParts: F[RawValue[Vector[RawPart]]] = rawPartsF.sequence.map { parts => - RawValue(parts, parts collect { case _ @Part(_, f: FileRange, _, _) => f }) - } + val rawParts: F[RawValue[Vector[RawPart]]] = rawPartsF.sequence.map { parts => + RawValue(parts, parts collect { case _ @Part(_, f: FileRange, _, _) => f }) + } - rawParts.asInstanceOf[F[RawValue[R]]] // R is Vector[RawPart] - } + rawParts.asInstanceOf[F[RawValue[R]]] // R is Vector[RawPart] + } } } + private def limitedMedia(media: Media[F], maxBytes: Option[Long]): Media[F] = maxBytes + .map(limit => + new Media[F] { + override def body: fs2.Stream[F, Byte] = Fs2Streams.limitBytes(media.body, limit) + override def headers: Headers = media.headers + override def covary[F2[x] >: F[x]]: Media[F2] = media.covary + } + ) + .getOrElse(media) + private def toRawPart[R](serverRequest: ServerRequest, part: multipart.Part[F], partType: RawBodyType[R]): F[Part[R]] = { val dispositionParams = part.headers.get[`Content-Disposition`].map(_.parameters).getOrElse(Map.empty) val charset = part.headers.get[`Content-Type`].flatMap(_.charset) - toRawFromStream(serverRequest, part.body, partType, charset) + toRawFromStream(serverRequest, part.body, partType, charset, maxBytes = None) .map(r => Part( part.name.getOrElse(""), diff --git a/server/http4s-server/src/test/scala/sttp/tapir/server/http4s/Http4sServerTest.scala b/server/http4s-server/src/test/scala/sttp/tapir/server/http4s/Http4sServerTest.scala index 00fc79bdd5..844e628024 100644 --- a/server/http4s-server/src/test/scala/sttp/tapir/server/http4s/Http4sServerTest.scala +++ b/server/http4s-server/src/test/scala/sttp/tapir/server/http4s/Http4sServerTest.scala @@ -136,8 +136,8 @@ class Http4sServerTest[R >: Fs2Streams[IO] with WebSockets] extends TestSuite wi def drainFs2(stream: Fs2Streams[IO]#BinaryStream): IO[Unit] = stream.compile.drain.void - new AllServerTests(createServerTest, interpreter, backend, maxContentLength = true).tests() ++ - new ServerStreamingTests(createServerTest, maxLengthSupported = true).tests(Fs2Streams[IO])(drainFs2) ++ + new AllServerTests(createServerTest, interpreter, backend).tests() ++ + new ServerStreamingTests(createServerTest).tests(Fs2Streams[IO])(drainFs2) ++ new ServerWebSocketTests(createServerTest, Fs2Streams[IO]) { override def functionToPipe[A, B](f: A => B): streams.Pipe[A, B] = in => in.map(f) override def emptyPipe[A, B]: Pipe[IO, A, B] = _ => fs2.Stream.empty diff --git a/server/http4s-server/zio/src/test/scala/sttp/tapir/server/http4s/ztapir/ZHttp4sServerTest.scala b/server/http4s-server/zio/src/test/scala/sttp/tapir/server/http4s/ztapir/ZHttp4sServerTest.scala index a0050a094d..99ff7c2d4d 100644 --- a/server/http4s-server/zio/src/test/scala/sttp/tapir/server/http4s/ztapir/ZHttp4sServerTest.scala +++ b/server/http4s-server/zio/src/test/scala/sttp/tapir/server/http4s/ztapir/ZHttp4sServerTest.scala @@ -53,8 +53,8 @@ class ZHttp4sServerTest extends TestSuite with OptionValues { def drainZStream(zStream: ZioStreams.BinaryStream): Task[Unit] = zStream.run(ZSink.drain) - new AllServerTests(createServerTest, interpreter, backend, maxContentLength = true).tests() ++ - new ServerStreamingTests(createServerTest, maxLengthSupported = true).tests(ZioStreams)(drainZStream) ++ + new AllServerTests(createServerTest, interpreter, backend).tests() ++ + new ServerStreamingTests(createServerTest).tests(ZioStreams)(drainZStream) ++ new ServerWebSocketTests(createServerTest, ZioStreams) { override def functionToPipe[A, B](f: A => B): streams.Pipe[A, B] = in => in.map(f) override def emptyPipe[A, B]: streams.Pipe[A, B] = _ => ZStream.empty diff --git a/server/jdkhttp-server/src/main/scala/sttp/tapir/server/jdkhttp/internal/JdkHttpRequestBody.scala b/server/jdkhttp-server/src/main/scala/sttp/tapir/server/jdkhttp/internal/JdkHttpRequestBody.scala index d3a3be59b5..0cb5fe8528 100644 --- a/server/jdkhttp-server/src/main/scala/sttp/tapir/server/jdkhttp/internal/JdkHttpRequestBody.scala +++ b/server/jdkhttp-server/src/main/scala/sttp/tapir/server/jdkhttp/internal/JdkHttpRequestBody.scala @@ -37,7 +37,7 @@ private[jdkhttp] class JdkHttpRequestBody(createFile: ServerRequest => TapirFile val file = createFile(serverRequest) Files.copy(asInputStream, file.toPath, StandardCopyOption.REPLACE_EXISTING) RawValue(FileRange(file), Seq(FileRange(file))) - case m: RawBodyType.MultipartBody => RawValue.fromParts(multiPartRequestToRawBody(serverRequest, m)) + case m: RawBodyType.MultipartBody => RawValue.fromParts(multiPartRequestToRawBody(serverRequest, asInputStream, m)) } } @@ -57,11 +57,11 @@ private[jdkhttp] class JdkHttpRequestBody(createFile: ServerRequest => TapirFile .getOrElse(throw new IllegalArgumentException("Unable to extract multipart boundary from multipart request")) } - private def multiPartRequestToRawBody(request: ServerRequest, m: RawBodyType.MultipartBody): Seq[RawPart] = { + private def multiPartRequestToRawBody(request: ServerRequest, requestBody: InputStream, m: RawBodyType.MultipartBody): Seq[RawPart] = { val httpExchange = jdkHttpRequest(request) val boundary = extractBoundary(httpExchange) - parseMultipartBody(httpExchange.getRequestBody, boundary, multipartFileThresholdBytes).flatMap(parsedPart => + parseMultipartBody(requestBody, boundary, multipartFileThresholdBytes).flatMap(parsedPart => parsedPart.getName.flatMap(name => m.partType(name) .map(partType => { diff --git a/server/jdkhttp-server/src/test/scala/sttp/tapir/server/jdkhttp/JdkHttpServerTest.scala b/server/jdkhttp-server/src/test/scala/sttp/tapir/server/jdkhttp/JdkHttpServerTest.scala index a6e8eb9854..5e3f42cfd4 100644 --- a/server/jdkhttp-server/src/test/scala/sttp/tapir/server/jdkhttp/JdkHttpServerTest.scala +++ b/server/jdkhttp-server/src/test/scala/sttp/tapir/server/jdkhttp/JdkHttpServerTest.scala @@ -14,7 +14,7 @@ class JdkHttpServerTest extends TestSuite with EitherValues { val interpreter = new JdkHttpTestServerInterpreter() val createServerTest = new DefaultCreateServerTest(backend, interpreter) - new ServerBasicTests(createServerTest, interpreter, invulnerableToUnsanitizedHeaders = false, maxContentLength = true).tests() ++ + new ServerBasicTests(createServerTest, interpreter, invulnerableToUnsanitizedHeaders = false).tests() ++ new AllServerTests(createServerTest, interpreter, backend, basic = false).tests() }) } diff --git a/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsServerTest.scala b/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsServerTest.scala index 52d86ae1d4..e7a1280719 100644 --- a/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsServerTest.scala +++ b/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsServerTest.scala @@ -34,11 +34,10 @@ class NettyCatsServerTest extends TestSuite with EitherValues { createServerTest, interpreter, backend, - multipart = false, - maxContentLength = true + multipart = false ) .tests() ++ - new ServerStreamingTests(createServerTest, maxLengthSupported = true).tests(Fs2Streams[IO])(drainFs2) ++ + new ServerStreamingTests(createServerTest).tests(Fs2Streams[IO])(drainFs2) ++ new ServerCancellationTests(createServerTest)(m, IO.asyncForIO).tests() ++ new NettyFs2StreamingCancellationTest(createServerTest).tests() ++ new ServerGracefulShutdownTests(createServerTest, ioSleeper).tests() diff --git a/server/netty-server/loom/src/test/scala/sttp/tapir/server/netty/loom/NettyIdServerTest.scala b/server/netty-server/loom/src/test/scala/sttp/tapir/server/netty/loom/NettyIdServerTest.scala index 73fe2cdd20..e4053d18be 100644 --- a/server/netty-server/loom/src/test/scala/sttp/tapir/server/netty/loom/NettyIdServerTest.scala +++ b/server/netty-server/loom/src/test/scala/sttp/tapir/server/netty/loom/NettyIdServerTest.scala @@ -22,7 +22,7 @@ class NettyIdServerTest extends TestSuite with EitherValues { val sleeper: Sleeper[Id] = (duration: FiniteDuration) => Thread.sleep(duration.toMillis) val tests = - new AllServerTests(createServerTest, interpreter, backend, staticContent = false, multipart = false, maxContentLength = true) + new AllServerTests(createServerTest, interpreter, backend, staticContent = false, multipart = false) .tests() ++ new ServerGracefulShutdownTests(createServerTest, sleeper).tests() diff --git a/server/netty-server/src/test/scala/sttp/tapir/server/netty/NettyFutureServerTest.scala b/server/netty-server/src/test/scala/sttp/tapir/server/netty/NettyFutureServerTest.scala index 5125a86532..f8f263706e 100644 --- a/server/netty-server/src/test/scala/sttp/tapir/server/netty/NettyFutureServerTest.scala +++ b/server/netty-server/src/test/scala/sttp/tapir/server/netty/NettyFutureServerTest.scala @@ -22,7 +22,7 @@ class NettyFutureServerTest extends TestSuite with EitherValues { val createServerTest = new DefaultCreateServerTest(backend, interpreter) val tests = - new AllServerTests(createServerTest, interpreter, backend, multipart = false, maxContentLength = true).tests() ++ + new AllServerTests(createServerTest, interpreter, backend, multipart = false).tests() ++ new ServerGracefulShutdownTests(createServerTest, Sleeper.futureSleeper).tests() (tests, eventLoopGroup) diff --git a/server/netty-server/zio/src/test/scala/sttp/tapir/server/netty/zio/NettyZioServerTest.scala b/server/netty-server/zio/src/test/scala/sttp/tapir/server/netty/zio/NettyZioServerTest.scala index 134de29376..72c3580f3d 100644 --- a/server/netty-server/zio/src/test/scala/sttp/tapir/server/netty/zio/NettyZioServerTest.scala +++ b/server/netty-server/zio/src/test/scala/sttp/tapir/server/netty/zio/NettyZioServerTest.scala @@ -39,10 +39,9 @@ class NettyZioServerTest extends TestSuite with EitherValues { interpreter, backend, staticContent = false, - multipart = false, - maxContentLength = true + multipart = false ).tests() ++ - new ServerStreamingTests(createServerTest, maxLengthSupported = true).tests(ZioStreams)(drainZStream) ++ + new ServerStreamingTests(createServerTest).tests(ZioStreams)(drainZStream) ++ new ServerCancellationTests(createServerTest)(monadError, asyncInstance).tests() ++ new ServerGracefulShutdownTests(createServerTest, zioSleeper).tests() diff --git a/server/nima-server/src/test/scala/sttp/tapir/server/nima/NimaServerTest.scala b/server/nima-server/src/test/scala/sttp/tapir/server/nima/NimaServerTest.scala index 7b34ce5541..af5bb5ebe8 100644 --- a/server/nima-server/src/test/scala/sttp/tapir/server/nima/NimaServerTest.scala +++ b/server/nima-server/src/test/scala/sttp/tapir/server/nima/NimaServerTest.scala @@ -14,7 +14,7 @@ class NimaServerTest extends TestSuite with EitherValues { val interpreter = new NimaTestServerInterpreter() val createServerTest = new DefaultCreateServerTest(backend, interpreter) // TODO uncomment static content tests when Nima starts to correctly support '*' in accept-encoding - new ServerBasicTests(createServerTest, interpreter, invulnerableToUnsanitizedHeaders = false).tests() ++ + new ServerBasicTests(createServerTest, interpreter, invulnerableToUnsanitizedHeaders = false, maxContentLength = false).tests() ++ new AllServerTests(createServerTest, interpreter, backend, basic = false, multipart = false, staticContent = false).tests() }) } diff --git a/server/pekko-http-server/src/main/scala/sttp/tapir/server/pekkohttp/PekkoHttpServerInterpreter.scala b/server/pekko-http-server/src/main/scala/sttp/tapir/server/pekkohttp/PekkoHttpServerInterpreter.scala index a796316ee4..cb91cc5d4e 100644 --- a/server/pekko-http-server/src/main/scala/sttp/tapir/server/pekkohttp/PekkoHttpServerInterpreter.scala +++ b/server/pekko-http-server/src/main/scala/sttp/tapir/server/pekkohttp/PekkoHttpServerInterpreter.scala @@ -20,11 +20,11 @@ import sttp.capabilities.pekko.PekkoStreams import sttp.model.Method import sttp.monad.FutureMonad import sttp.tapir.server.ServerEndpoint -import sttp.tapir.server.pekkohttp.PekkoModel.parseHeadersOrThrowWithoutContentHeaders import sttp.tapir.server.interceptor.RequestResult import sttp.tapir.server.interceptor.reject.RejectInterceptor import sttp.tapir.server.interpreter.{BodyListener, FilterServerEndpoints, RequestBody, ServerInterpreter, ToResponseBody} import sttp.tapir.server.model.ServerResponse +import sttp.tapir.server.pekkohttp.PekkoModel.parseHeadersOrThrowWithoutContentHeaders import scala.concurrent.{ExecutionContext, Future} @@ -44,7 +44,10 @@ trait PekkoHttpServerInterpreter { toResponseBody: (Materializer, ExecutionContext) => ToResponseBody[PekkoResponseBody, PekkoStreams] )(ses: List[ServerEndpoint[PekkoStreams with WebSockets, Future]]): Route = { val filterServerEndpoints = FilterServerEndpoints(ses) - val interceptors = RejectInterceptor.disableWhenSingleEndpoint(pekkoHttpServerOptions.interceptors, ses) + val interceptors = RejectInterceptor.disableWhenSingleEndpoint( + pekkoHttpServerOptions.appendInterceptor(PekkoStreamSizeExceptionInterceptor).interceptors, + ses + ) extractExecutionContext { implicit ec => extractMaterializer { implicit mat => diff --git a/server/pekko-http-server/src/main/scala/sttp/tapir/server/pekkohttp/PekkoRequestBody.scala b/server/pekko-http-server/src/main/scala/sttp/tapir/server/pekkohttp/PekkoRequestBody.scala index 2f37d6e28c..aad35e0042 100644 --- a/server/pekko-http-server/src/main/scala/sttp/tapir/server/pekkohttp/PekkoRequestBody.scala +++ b/server/pekko-http-server/src/main/scala/sttp/tapir/server/pekkohttp/PekkoRequestBody.scala @@ -1,18 +1,16 @@ package sttp.tapir.server.pekkohttp -import org.apache.pekko.http.scaladsl.model.{HttpEntity, Multipart} +import org.apache.pekko.http.scaladsl.model.{EntityStreamSizeException, HttpEntity, Multipart, RequestEntity} import org.apache.pekko.http.scaladsl.server.RequestContext import org.apache.pekko.http.scaladsl.unmarshalling.FromEntityUnmarshaller +import org.apache.pekko.stream.scaladsl.{FileIO, Sink, _} import org.apache.pekko.stream.Materializer -import org.apache.pekko.stream.scaladsl.{FileIO, Sink} import org.apache.pekko.util.ByteString import sttp.capabilities.pekko.PekkoStreams import sttp.model.{Header, Part} import sttp.tapir.model.ServerRequest import sttp.tapir.server.interpreter.{RawValue, RequestBody} -import sttp.tapir.{FileRange, RawBodyType, RawPart, InputStreamRange} - -import java.io.ByteArrayInputStream +import sttp.tapir.{FileRange, InputStreamRange, RawBodyType, RawPart} import scala.concurrent.{ExecutionContext, Future} @@ -22,29 +20,44 @@ private[pekkohttp] class PekkoRequestBody(serverOptions: PekkoHttpServerOptions) ) extends RequestBody[Future, PekkoStreams] { override val streams: PekkoStreams = PekkoStreams override def toRaw[R](request: ServerRequest, bodyType: RawBodyType[R], maxBytes: Option[Long]): Future[RawValue[R]] = - toRawFromEntity(request, akkeRequestEntity(request), bodyType) + toRawFromEntity(request, requestEntity(request, maxBytes), bodyType) + override def toStream(request: ServerRequest, maxBytes: Option[Long]): streams.BinaryStream = { - val stream = akkeRequestEntity(request).dataBytes - maxBytes.map(PekkoStreams.limitBytes(stream, _)).getOrElse(stream) + requestEntity(request, maxBytes).dataBytes } - private def akkeRequestEntity(request: ServerRequest) = request.underlying.asInstanceOf[RequestContext].request.entity + private def requestEntity(request: ServerRequest, maxBytes: Option[Long]): RequestEntity = { + val entity = request.underlying.asInstanceOf[RequestContext].request.entity + maxBytes.map(entity.withSizeLimit).getOrElse(entity) + } - private def toRawFromEntity[R](request: ServerRequest, body: HttpEntity, bodyType: RawBodyType[R]): Future[RawValue[R]] = { + private def toRawFromEntity[R]( + request: ServerRequest, + body: HttpEntity, + bodyType: RawBodyType[R] + ): Future[RawValue[R]] = { bodyType match { case RawBodyType.StringBody(_) => implicitly[FromEntityUnmarshaller[String]].apply(body).map(RawValue(_)) case RawBodyType.ByteArrayBody => implicitly[FromEntityUnmarshaller[Array[Byte]]].apply(body).map(RawValue(_)) case RawBodyType.ByteBufferBody => implicitly[FromEntityUnmarshaller[ByteString]].apply(body).map(b => RawValue(b.asByteBuffer)) case RawBodyType.InputStreamBody => - implicitly[FromEntityUnmarshaller[Array[Byte]]].apply(body).map(b => RawValue(new ByteArrayInputStream(b))) + Future.successful(RawValue(body.dataBytes.runWith(StreamConverters.asInputStream()))) case RawBodyType.FileBody => serverOptions .createFile(request) - .flatMap(file => body.dataBytes.runWith(FileIO.toPath(file.toPath)).map(_ => FileRange(file)).map(f => RawValue(f, Seq(f)))) + .flatMap(file => + body.dataBytes + .runWith(FileIO.toPath(file.toPath)) + .recoverWith { + // We need to dig out EntityStreamSizeException from an external wrapper applied by FileIO sink + case e: Exception if e.getCause().isInstanceOf[EntityStreamSizeException] => + Future.failed(e.getCause()) + } + .map(_ => FileRange(file)) + .map(f => RawValue(f, Seq(f))) + ) case RawBodyType.InputStreamRangeBody => - implicitly[FromEntityUnmarshaller[Array[Byte]]] - .apply(body) - .map(b => RawValue(InputStreamRange(() => new ByteArrayInputStream(b)))) + Future.successful(RawValue(InputStreamRange(() => body.dataBytes.runWith(StreamConverters.asInputStream())))) case m: RawBodyType.MultipartBody => implicitly[FromEntityUnmarshaller[Multipart.FormData]].apply(body).flatMap { fd => fd.parts diff --git a/server/pekko-http-server/src/main/scala/sttp/tapir/server/pekkohttp/PekkoStreamSizeExceptionInterceptor.scala b/server/pekko-http-server/src/main/scala/sttp/tapir/server/pekkohttp/PekkoStreamSizeExceptionInterceptor.scala new file mode 100644 index 0000000000..28078959cd --- /dev/null +++ b/server/pekko-http-server/src/main/scala/sttp/tapir/server/pekkohttp/PekkoStreamSizeExceptionInterceptor.scala @@ -0,0 +1,33 @@ +package sttp.tapir.server.pekkohttp + +import org.apache.pekko.http.scaladsl.model.EntityStreamSizeException +import sttp.capabilities.StreamMaxLengthExceededException +import sttp.monad.MonadError +import sttp.tapir.server.interceptor.exception.{ExceptionContext, ExceptionHandler, ExceptionInterceptor} +import sttp.tapir.server.model.ValuedEndpointOutput + +import scala.concurrent.Future + +/** Used by PekkoHttpServerInterpreter to catch specific scenarios related to exceeding max content length in requests: + * - EntityStreamSizeException thrown when InputBody is a Pekko Stream, which exceeds max length limit during processing in serverLogic. + * - A wrapped EntityStreamSizeException failure, a variant of previous scenario where additional stage (like FileIO sink) wraps the + * underlying cause into another exception. + * - An InputStreamBody throws an IOException(EntityStreamSizeException) when reading the input stream fails due to exceeding max length + * limit in the underlying Pekko Stream. + * + * All these scenarios mean basically the same, so we'll fail with our own StreamMaxLengthExceededException, a general mechanism intended + * to be handled by Tapir and result in a HTTP 413 Payload Too Large response. + */ +private[pekkohttp] object PekkoStreamSizeExceptionInterceptor + extends ExceptionInterceptor[Future](new ExceptionHandler[Future] { + override def apply(ctx: ExceptionContext)(implicit monad: MonadError[Future]): Future[Option[ValuedEndpointOutput[_]]] = { + ctx.e match { + case ex: Exception if ex.getCause().isInstanceOf[EntityStreamSizeException] => + monad.error(StreamMaxLengthExceededException(ex.getCause().asInstanceOf[EntityStreamSizeException].limit)) + case EntityStreamSizeException(limit, _) => + monad.error(StreamMaxLengthExceededException(limit)) + case other => + monad.error(other) + } + } + }) diff --git a/server/pekko-http-server/src/test/scala/sttp/tapir/server/pekkohttp/PekkoHttpServerTest.scala b/server/pekko-http-server/src/test/scala/sttp/tapir/server/pekkohttp/PekkoHttpServerTest.scala index ebfe952eb6..7f15eecf29 100644 --- a/server/pekko-http-server/src/test/scala/sttp/tapir/server/pekkohttp/PekkoHttpServerTest.scala +++ b/server/pekko-http-server/src/test/scala/sttp/tapir/server/pekkohttp/PekkoHttpServerTest.scala @@ -104,7 +104,7 @@ class PekkoHttpServerTest extends TestSuite with EitherValues { stream.runWith(Sink.ignore).map(_ => ()) new AllServerTests(createServerTest, interpreter, backend).tests() ++ - new ServerStreamingTests(createServerTest, maxLengthSupported = true).tests(PekkoStreams)(drainPekko) ++ + new ServerStreamingTests(createServerTest).tests(PekkoStreams)(drainPekko) ++ new ServerWebSocketTests(createServerTest, PekkoStreams) { override def functionToPipe[A, B](f: A => B): streams.Pipe[A, B] = Flow.fromFunction(f) override def emptyPipe[A, B]: Flow[A, B, Any] = Flow.fromSinkAndSource(Sink.ignore, Source.empty) diff --git a/server/play-server/src/main/scala/sttp/tapir/server/play/PlayRequestBody.scala b/server/play-server/src/main/scala/sttp/tapir/server/play/PlayRequestBody.scala index 741a5d7e64..04383c35a8 100644 --- a/server/play-server/src/main/scala/sttp/tapir/server/play/PlayRequestBody.scala +++ b/server/play-server/src/main/scala/sttp/tapir/server/play/PlayRequestBody.scala @@ -22,7 +22,8 @@ private[play] class PlayRequestBody(serverOptions: PlayServerOptions)(implicit ) extends RequestBody[Future, PekkoStreams] { override val streams: PekkoStreams = PekkoStreams - val parsers = serverOptions.playBodyParsers + private val parsers = serverOptions.playBodyParsers + private lazy val filePartHandler = Multipart.handleFilePartAsTemporaryFile(serverOptions.temporaryFileCreator) override def toRaw[R](serverRequest: ServerRequest, bodyType: RawBodyType[R], maxBytes: Option[Long]): Future[RawValue[R]] = { import mat.executionContext @@ -84,21 +85,20 @@ private[play] class PlayRequestBody(serverOptions: PlayServerOptions)(implicit case Right(_) => Future.successful(RawValue(file, Seq(file))) } } - case m: RawBodyType.MultipartBody => multiPartRequestToRawBody(request, m, body) + case m: RawBodyType.MultipartBody => multiPartRequestToRawBody(request, m, body, maxBytes) } } private def multiPartRequestToRawBody( request: Request[PekkoStreams.BinaryStream], m: RawBodyType.MultipartBody, - body: () => Source[ByteString, Any] + body: () => Source[ByteString, Any], + maxBytes: Option[Long] )(implicit mat: Materializer, ec: ExecutionContext ): Future[RawValue[Seq[RawPart]]] = { - val bodyParser = serverOptions.playBodyParsers.multipartFormData( - Multipart.handleFilePartAsTemporaryFile(serverOptions.temporaryFileCreator) - ) + val bodyParser = maxBytes.map(parsers.multipartFormData(filePartHandler, _)).getOrElse(parsers.multipartFormData(filePartHandler)) bodyParser.apply(request).run(body()).flatMap { case Left(r) => Future.failed(new PlayBodyParserException(r)) @@ -127,7 +127,7 @@ private[play] class PlayRequestBody(serverOptions: PlayServerOptions)(implicit charset(partType), () => FileIO.fromPath(f.ref.path), Some(f.ref.toFile), - maxBytes = None + maxBytes = None, ).map(body => Some( Part( diff --git a/server/play-server/src/test/scala/sttp/tapir/server/play/PlayServerTest.scala b/server/play-server/src/test/scala/sttp/tapir/server/play/PlayServerTest.scala index c66265359b..3662a63fd3 100644 --- a/server/play-server/src/test/scala/sttp/tapir/server/play/PlayServerTest.scala +++ b/server/play-server/src/test/scala/sttp/tapir/server/play/PlayServerTest.scala @@ -109,8 +109,7 @@ class PlayServerTest extends TestSuite { interpreter, multipleValueHeaderSupport = false, inputStreamSupport = false, - invulnerableToUnsanitizedHeaders = false, - maxContentLength = true + invulnerableToUnsanitizedHeaders = false ).tests() ++ new ServerMultipartTests(createServerTest, partOtherHeaderSupport = false).tests() ++ new AllServerTests( @@ -121,7 +120,7 @@ class PlayServerTest extends TestSuite { multipart = false, options = false ).tests() ++ - new ServerStreamingTests(createServerTest, maxLengthSupported = true).tests(PekkoStreams)(drainPekko) ++ + new ServerStreamingTests(createServerTest).tests(PekkoStreams)(drainPekko) ++ new PlayServerWithContextTest(backend).tests() ++ new ServerWebSocketTests(createServerTest, PekkoStreams) { override def functionToPipe[A, B](f: A => B): streams.Pipe[A, B] = Flow.fromFunction(f) diff --git a/server/play29-server/src/main/scala/sttp/tapir/server/play/PlayRequestBody.scala b/server/play29-server/src/main/scala/sttp/tapir/server/play/PlayRequestBody.scala index 8e1a606140..a8dda4485c 100644 --- a/server/play29-server/src/main/scala/sttp/tapir/server/play/PlayRequestBody.scala +++ b/server/play29-server/src/main/scala/sttp/tapir/server/play/PlayRequestBody.scala @@ -22,7 +22,8 @@ private[play] class PlayRequestBody(serverOptions: PlayServerOptions)(implicit ) extends RequestBody[Future, AkkaStreams] { override val streams: AkkaStreams = AkkaStreams - val parsers = serverOptions.playBodyParsers + private val parsers = serverOptions.playBodyParsers + private lazy val filePartHandler = Multipart.handleFilePartAsTemporaryFile(serverOptions.temporaryFileCreator) override def toRaw[R](serverRequest: ServerRequest, bodyType: RawBodyType[R], maxBytes: Option[Long]): Future[RawValue[R]] = { import mat.executionContext @@ -84,21 +85,20 @@ private[play] class PlayRequestBody(serverOptions: PlayServerOptions)(implicit case Right(_) => Future.successful(RawValue(file, Seq(file))) } } - case m: RawBodyType.MultipartBody => multiPartRequestToRawBody(request, m, body) + case m: RawBodyType.MultipartBody => multiPartRequestToRawBody(request, m, body, maxBytes) } } private def multiPartRequestToRawBody( request: Request[AkkaStreams.BinaryStream], m: RawBodyType.MultipartBody, - body: () => Source[ByteString, Any] + body: () => Source[ByteString, Any], + maxBytes: Option[Long] )(implicit mat: Materializer, ec: ExecutionContext ): Future[RawValue[Seq[RawPart]]] = { - val bodyParser = serverOptions.playBodyParsers.multipartFormData( - Multipart.handleFilePartAsTemporaryFile(serverOptions.temporaryFileCreator) - ) + val bodyParser = maxBytes.map(parsers.multipartFormData(filePartHandler, _)).getOrElse(parsers.multipartFormData(filePartHandler)) bodyParser.apply(request).run(body()).flatMap { case Left(r) => Future.failed(new PlayBodyParserException(r)) diff --git a/server/play29-server/src/test/scala/sttp/tapir/server/play/PlayServerTest.scala b/server/play29-server/src/test/scala/sttp/tapir/server/play/PlayServerTest.scala index b7fe2f1d44..096b077f44 100644 --- a/server/play29-server/src/test/scala/sttp/tapir/server/play/PlayServerTest.scala +++ b/server/play29-server/src/test/scala/sttp/tapir/server/play/PlayServerTest.scala @@ -111,12 +111,11 @@ class PlayServerTest extends TestSuite { interpreter, multipleValueHeaderSupport = false, inputStreamSupport = false, - invulnerableToUnsanitizedHeaders = false, - maxContentLength = true + invulnerableToUnsanitizedHeaders = false ).tests() ++ new ServerMultipartTests(createServerTest, partOtherHeaderSupport = false).tests() ++ new AllServerTests(createServerTest, interpreter, backend, basic = false, multipart = false, options = false).tests() ++ - new ServerStreamingTests(createServerTest, maxLengthSupported = true).tests(AkkaStreams)(drainAkka) ++ + new ServerStreamingTests(createServerTest).tests(AkkaStreams)(drainAkka) ++ new PlayServerWithContextTest(backend).tests() ++ new ServerWebSocketTests(createServerTest, AkkaStreams) { override def functionToPipe[A, B](f: A => B): streams.Pipe[A, B] = Flow.fromFunction(f) diff --git a/server/tests/src/main/scala/sttp/tapir/server/tests/AllServerTests.scala b/server/tests/src/main/scala/sttp/tapir/server/tests/AllServerTests.scala index 922e6973eb..543048ef90 100644 --- a/server/tests/src/main/scala/sttp/tapir/server/tests/AllServerTests.scala +++ b/server/tests/src/main/scala/sttp/tapir/server/tests/AllServerTests.scala @@ -28,7 +28,7 @@ class AllServerTests[F[_], OPTIONS, ROUTE]( oneOfBody: Boolean = true, cors: Boolean = true, options: Boolean = true, - maxContentLength: Boolean = false // TODO let's work towards making this true by default + maxContentLength: Boolean = true )(implicit m: MonadError[F] ) { @@ -39,7 +39,7 @@ class AllServerTests[F[_], OPTIONS, ROUTE]( (if (file) new ServerFileTests(createServerTest).tests() else Nil) ++ (if (mapping) new ServerMappingTests(createServerTest).tests() else Nil) ++ (if (metrics) new ServerMetricsTest(createServerTest).tests() else Nil) ++ - (if (multipart) new ServerMultipartTests(createServerTest).tests() else Nil) ++ + (if (multipart) new ServerMultipartTests(createServerTest, maxContentLengthSupport = maxContentLength).tests() else Nil) ++ (if (oneOf) new ServerOneOfTests(createServerTest).tests() else Nil) ++ (if (reject) new ServerRejectTests(createServerTest, serverInterpreter).tests() else Nil) ++ (if (staticContent) new ServerFilesTests(serverInterpreter, backend).tests() else Nil) ++ diff --git a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerBasicTests.scala b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerBasicTests.scala index 3d7f199039..000abfedb1 100644 --- a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerBasicTests.scala +++ b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerBasicTests.scala @@ -36,7 +36,7 @@ class ServerBasicTests[F[_], OPTIONS, ROUTE]( supportsUrlEncodedPathSegments: Boolean = true, supportsMultipleSetCookieHeaders: Boolean = true, invulnerableToUnsanitizedHeaders: Boolean = true, - maxContentLength: Boolean = false + maxContentLength: Boolean = true )(implicit m: MonadError[F] ) { diff --git a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerMultipartTests.scala b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerMultipartTests.scala index d83d9232bc..157b27cce3 100644 --- a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerMultipartTests.scala +++ b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerMultipartTests.scala @@ -5,6 +5,8 @@ import org.scalatest.matchers.should.Matchers._ import sttp.client3.{multipartFile, _} import sttp.model.{Part, StatusCode} import sttp.monad.MonadError +import sttp.tapir._ +import sttp.tapir.generic.auto._ import sttp.tapir.tests.Multipart.{ in_file_list_multipart_out_multipart, in_file_multipart_out_multipart, @@ -13,8 +15,9 @@ import sttp.tapir.tests.Multipart.{ in_simple_multipart_out_string } import sttp.tapir.tests.TestUtil.{readFromFile, writeToFile} -import sttp.tapir.tests.data.{FruitAmount, FruitData} +import sttp.tapir.tests.data.{DoubleFruit, FruitAmount, FruitData} import sttp.tapir.tests.{MultipleFileUpload, Test, data} +import sttp.tapir.server.model.EndpointExtensions._ import scala.concurrent.Await import scala.concurrent.duration.DurationInt @@ -22,12 +25,39 @@ import scala.concurrent.duration.DurationInt class ServerMultipartTests[F[_], OPTIONS, ROUTE]( createServerTest: CreateServerTest[F, Any, OPTIONS, ROUTE], partContentTypeHeaderSupport: Boolean = true, - partOtherHeaderSupport: Boolean = true + partOtherHeaderSupport: Boolean = true, + maxContentLengthSupport: Boolean = true )(implicit m: MonadError[F]) { import createServerTest._ def tests(): List[Test] = - basicTests() ++ (if (partContentTypeHeaderSupport) contentTypeHeaderTests() else Nil) + basicTests() ++ (if (partContentTypeHeaderSupport) contentTypeHeaderTests() else Nil) ++ + (if (maxContentLengthSupport) maxContentLengthTests() else Nil) + + def maxContentLengthTests(): List[Test] = List( + testServer( + endpoint.post + .in("api" / "echo" / "multipart") + .in(multipartBody[DoubleFruit]) + .out(stringBody) + .maxRequestBodyLength(15000), + "multipart with maxContentLength" + )((df: DoubleFruit) => pureResult(("ok").asRight[Unit])) { (backend, baseUri) => + basicStringRequest + .post(uri"$baseUri/api/echo/multipart") + .multipartBody(multipart("fruitA", "pineapple".repeat(1100)), multipart("fruitB", "maracuja".repeat(1200))) + .send(backend) + .map { r => + r.code shouldBe StatusCode.PayloadTooLarge + } >> basicStringRequest + .post(uri"$baseUri/api/echo/multipart") + .multipartBody(multipart("fruitA", "pineapple".repeat(850)), multipart("fruitB", "maracuja".repeat(850))) + .send(backend) + .map { r => + r.code shouldBe StatusCode.Ok + } + } + ) def basicTests(): List[Test] = { List( diff --git a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerStreamingTests.scala b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerStreamingTests.scala index 2a66e8326e..6bc1b5ad67 100644 --- a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerStreamingTests.scala +++ b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerStreamingTests.scala @@ -16,7 +16,7 @@ import sttp.capabilities.fs2.Fs2Streams class ServerStreamingTests[F[_], S, OPTIONS, ROUTE]( createServerTest: CreateServerTest[F, S, OPTIONS, ROUTE], - maxLengthSupported: Boolean + maxLengthSupported: Boolean = true )(implicit m: MonadError[F] ) { diff --git a/server/vertx-server/cats/src/test/scala/sttp/tapir/server/vertx/cats/CatsVertxServerTest.scala b/server/vertx-server/cats/src/test/scala/sttp/tapir/server/vertx/cats/CatsVertxServerTest.scala index d3b532bea2..9a96f67fd1 100644 --- a/server/vertx-server/cats/src/test/scala/sttp/tapir/server/vertx/cats/CatsVertxServerTest.scala +++ b/server/vertx-server/cats/src/test/scala/sttp/tapir/server/vertx/cats/CatsVertxServerTest.scala @@ -23,13 +23,20 @@ class CatsVertxServerTest extends TestSuite { val interpreter = new CatsVertxTestServerInterpreter(vertx, dispatcher) val createServerTest = new DefaultCreateServerTest(backend, interpreter) - new AllServerTests(createServerTest, interpreter, backend, multipart = false, reject = false, options = false).tests() ++ + new AllServerTests( + createServerTest, + interpreter, + backend, + multipart = false, + reject = false, + options = false + ).tests() ++ new ServerMultipartTests( createServerTest, partContentTypeHeaderSupport = false, // README: doesn't seem supported but I may be wrong partOtherHeaderSupport = false ).tests() ++ - new ServerStreamingTests(createServerTest, maxLengthSupported = true).tests(Fs2Streams.apply[IO])(drainFs2) ++ + new ServerStreamingTests(createServerTest).tests(Fs2Streams.apply[IO])(drainFs2) ++ new ServerWebSocketTests(createServerTest, Fs2Streams.apply[IO]) { override def functionToPipe[A, B](f: A => B): streams.Pipe[A, B] = in => in.map(f) override def emptyPipe[A, B]: streams.Pipe[A, B] = _ => Stream.empty diff --git a/server/vertx-server/src/main/scala/sttp/tapir/server/vertx/decoders/VertxRequestBody.scala b/server/vertx-server/src/main/scala/sttp/tapir/server/vertx/decoders/VertxRequestBody.scala index e72a8ec9b5..5973b0b712 100644 --- a/server/vertx-server/src/main/scala/sttp/tapir/server/vertx/decoders/VertxRequestBody.scala +++ b/server/vertx-server/src/main/scala/sttp/tapir/server/vertx/decoders/VertxRequestBody.scala @@ -26,6 +26,7 @@ class VertxRequestBody[F[_], S <: Streams[S]]( extends RequestBody[F, S] { override val streams: Streams[S] = readStreamCompatible.streams + // We can ignore maxBytes here, because vertx native body limit check is attached to endpoints by methods in sttp.tapir.server.vertx.handlers override def toRaw[R](serverRequest: ServerRequest, bodyType: RawBodyType[R], maxBytes: Option[Long]): F[RawValue[R]] = { val rc = routingContext(serverRequest) fromVFuture(bodyType match { diff --git a/server/vertx-server/src/main/scala/sttp/tapir/server/vertx/handlers/package.scala b/server/vertx-server/src/main/scala/sttp/tapir/server/vertx/handlers/package.scala index 3f86d64240..82091cb853 100644 --- a/server/vertx-server/src/main/scala/sttp/tapir/server/vertx/handlers/package.scala +++ b/server/vertx-server/src/main/scala/sttp/tapir/server/vertx/handlers/package.scala @@ -6,14 +6,15 @@ import io.vertx.ext.web.handler.BodyHandler import sttp.tapir.{Endpoint, EndpointIO, EndpointOutput} import sttp.tapir.RawBodyType.MultipartBody import sttp.tapir.internal._ +import sttp.tapir.server.model.MaxContentLength package object handlers { - private[vertx] lazy val bodyHandler = BodyHandler.create(false) - - private[vertx] def multipartHandler(uploadDirectory: String): Handler[RoutingContext] = { rc => + private[vertx] def multipartHandler(uploadDirectory: String, maxBytes: Option[Long]): Handler[RoutingContext] = { rc => rc.request.setExpectMultipart(true) - bodyHandler + maxBytes + .map(BodyHandler.create(false).setBodyLimit) + .getOrElse(BodyHandler.create(false)) .setHandleFileUploads(true) .setUploadsDirectory(uploadDirectory) .handle(rc) @@ -36,12 +37,18 @@ package object handlers { case _ => Vector.empty } + val maxBytes: Option[Long] = + e.info + .attribute(MaxContentLength.attributeKey) + .map(_.value) + mbWebsocketType.headOption.orElse(bodyType.headOption) match { - case Some(MultipartBody(_, _)) => route.handler(multipartHandler(uploadDirectory)) + case Some(MultipartBody(_, _)) => route.handler(multipartHandler(uploadDirectory, maxBytes)) case Some(_: EndpointIO.StreamBodyWrapper[_, _]) => route.handler(streamPauseHandler) case Some(_: EndpointOutput.WebSocketBodyWrapper[_, _]) => route.handler(streamPauseHandler) - case Some(_) => route.handler(bodyHandler) - case None => () + case Some(_) => + route.handler(maxBytes.map(BodyHandler.create(false).setBodyLimit).getOrElse(BodyHandler.create(false))) + case None => () } route diff --git a/server/vertx-server/src/main/scala/sttp/tapir/server/vertx/streams/LimitedReadStream.scala b/server/vertx-server/src/main/scala/sttp/tapir/server/vertx/streams/LimitedReadStream.scala new file mode 100644 index 0000000000..79f6e83b80 --- /dev/null +++ b/server/vertx-server/src/main/scala/sttp/tapir/server/vertx/streams/LimitedReadStream.scala @@ -0,0 +1,57 @@ +package sttp.tapir.server.vertx.streams + +import io.vertx.core.streams.ReadStream +import io.vertx.core.Handler +import io.vertx.core.buffer.Buffer +import sttp.capabilities.StreamMaxLengthExceededException + +/** + * An adapter for Vertx ReadStream[Buffer], which passes bytes through, but fails with a [[sttp.capabilities.StreamMaxLengthExceededException]] if exceeds given limit. + * This exception should be handled by [[sttp.tapir.server.interceptor.exception.DefaultExceptionHandler]] in order to return a HTTP 413 Payload Too Large. + + */ +private[vertx] class LimitedReadStream(source: ReadStream[Buffer], maxBytes: Long) extends ReadStream[Buffer] { + + // Safe, Vertx uses a single thread + private var bytesReadSoFar: Long = 0 + private var endHandler: Handler[Void] = _ + private var exceptionHandler: Handler[Throwable] = _ + private var dataHandler: Handler[Buffer] = _ + + override def handler(handler: Handler[Buffer]): ReadStream[Buffer] = { + dataHandler = (buffer: Buffer) => { + bytesReadSoFar += buffer.length() + if (bytesReadSoFar > maxBytes) { + if (exceptionHandler != null) { + exceptionHandler.handle(new StreamMaxLengthExceededException(maxBytes)) + } + } else { + handler.handle(buffer) + } + } + source.handler(dataHandler) + this + } + + override def exceptionHandler(handler: Handler[Throwable]): ReadStream[Buffer] = { + this.exceptionHandler = handler + source.exceptionHandler(handler) + this + } + override def pause(): ReadStream[Buffer] = { + source.pause() + this + } + override def resume(): ReadStream[Buffer] = { + fetch(Long.MaxValue) + } + override def fetch(amount: Long): ReadStream[Buffer] = { + source.fetch(amount) + this + } + override def endHandler(endHandler: Handler[Void]): ReadStream[Buffer] = { + this.endHandler = endHandler + source.endHandler(endHandler) + this + } +} diff --git a/server/vertx-server/src/main/scala/sttp/tapir/server/vertx/streams/package.scala b/server/vertx-server/src/main/scala/sttp/tapir/server/vertx/streams/package.scala index e9197ac7b8..bbd928c0f5 100644 --- a/server/vertx-server/src/main/scala/sttp/tapir/server/vertx/streams/package.scala +++ b/server/vertx-server/src/main/scala/sttp/tapir/server/vertx/streams/package.scala @@ -18,8 +18,8 @@ package object streams { override def asReadStream(readStream: ReadStream[Buffer]): ReadStream[Buffer] = readStream - override def fromReadStream(readStream: ReadStream[Buffer], maxBytes: Option[Long]): ReadStream[Buffer] = // TODO support maxBytes - readStream + override def fromReadStream(readStream: ReadStream[Buffer], maxBytes: Option[Long]): ReadStream[Buffer] = + maxBytes.map(new LimitedReadStream(readStream, _)).getOrElse(readStream) override def webSocketPipe[REQ, RESP]( readStream: ReadStream[WebSocketFrame], diff --git a/server/vertx-server/src/test/scala/sttp/tapir/server/vertx/VertxServerTest.scala b/server/vertx-server/src/test/scala/sttp/tapir/server/vertx/VertxServerTest.scala index 8ffe8f51b4..8a7a8f57ce 100644 --- a/server/vertx-server/src/test/scala/sttp/tapir/server/vertx/VertxServerTest.scala +++ b/server/vertx-server/src/test/scala/sttp/tapir/server/vertx/VertxServerTest.scala @@ -10,6 +10,8 @@ import sttp.tapir.tests.{Test, TestSuite} import scala.concurrent.ExecutionContext import scala.concurrent.Future +import scala.concurrent.Promise +import io.vertx.core.buffer.Buffer class VertxServerTest extends TestSuite { def vertxResource: Resource[IO, Vertx] = @@ -22,12 +24,35 @@ class VertxServerTest extends TestSuite { val interpreter = new VertxTestServerInterpreter(vertx) val createServerTest = new DefaultCreateServerTest(backend, interpreter) - new AllServerTests(createServerTest, interpreter, backend, multipart = false, reject = false, options = false).tests() ++ + def drainVertx[T](source: ReadStream[T]): Future[Unit] = { + val p = Promise[Unit]() + // Handler for stream data - do nothing with the data + val dataHandler: Handler[T] = (_: T) => () + + // End handler - complete the promise when the stream ends + val endHandler: Handler[Void] = (_: Void) => p.success(()) + + // Exception handler - fail the promise if an error occurs + val exceptionHandler: Handler[Throwable] = (t: Throwable) => p.failure(t) + + source.handler(dataHandler).endHandler(endHandler).exceptionHandler(exceptionHandler).fetch(Long.MaxValue).resume() + + p.future + } + + new AllServerTests( + createServerTest, + interpreter, + backend, + multipart = false, + reject = false, + options = false + ).tests() ++ new ServerMultipartTests( createServerTest, partContentTypeHeaderSupport = true, partOtherHeaderSupport = false - ).tests() ++ new ServerStreamingTests(createServerTest, maxLengthSupported = false).tests(VertxStreams)(_ => Future.unit) ++ + ).tests() ++ new ServerStreamingTests(createServerTest).tests(VertxStreams)(drainVertx[Buffer]) ++ (new ServerWebSocketTests(createServerTest, VertxStreams) { override def functionToPipe[A, B](f: A => B): VertxStreams.Pipe[A, B] = in => new ReadStreamMapping(in, f) override def emptyPipe[A, B]: VertxStreams.Pipe[A, B] = _ => new EmptyReadStream() diff --git a/server/vertx-server/zio/src/test/scala/sttp/tapir/server/vertx/zio/ZioVertxServerTest.scala b/server/vertx-server/zio/src/test/scala/sttp/tapir/server/vertx/zio/ZioVertxServerTest.scala index ea448196fd..0a8c27f4e4 100644 --- a/server/vertx-server/zio/src/test/scala/sttp/tapir/server/vertx/zio/ZioVertxServerTest.scala +++ b/server/vertx-server/zio/src/test/scala/sttp/tapir/server/vertx/zio/ZioVertxServerTest.scala @@ -42,7 +42,7 @@ class ZioVertxServerTest extends TestSuite with OptionValues { partContentTypeHeaderSupport = false, // README: doesn't seem supported but I may be wrong partOtherHeaderSupport = false ).tests() ++ additionalTests() ++ - new ServerStreamingTests(createServerTest, maxLengthSupported = true).tests(ZioStreams)(drainZStream) ++ + new ServerStreamingTests(createServerTest).tests(ZioStreams)(drainZStream) ++ new ServerWebSocketTests(createServerTest, ZioStreams) { override def functionToPipe[A, B](f: A => B): streams.Pipe[A, B] = in => in.map(f) override def emptyPipe[A, B]: streams.Pipe[A, B] = _ => ZStream.empty diff --git a/server/vertx-server/zio1/src/main/scala/sttp/tapir/server/vertx/zio/VertxZioServerInterpreter.scala b/server/vertx-server/zio1/src/main/scala/sttp/tapir/server/vertx/zio/VertxZioServerInterpreter.scala deleted file mode 100644 index c799c8652d..0000000000 --- a/server/vertx-server/zio1/src/main/scala/sttp/tapir/server/vertx/zio/VertxZioServerInterpreter.scala +++ /dev/null @@ -1,139 +0,0 @@ -package sttp.tapir.server.vertx.zio - -import io.vertx.core.{Future, Handler, Promise} -import io.vertx.ext.web.{Route, Router, RoutingContext} -import sttp.capabilities.WebSockets -import sttp.capabilities.zio.ZioStreams -import sttp.tapir.server.interceptor.RequestResult -import sttp.tapir.server.interpreter.{BodyListener, ServerInterpreter} -import sttp.tapir.server.vertx.VertxErrorHandler -import sttp.tapir.server.vertx.zio.VertxZioServerInterpreter.{RioFromVFuture, VertxFutureToRIO, ZioRunAsync} -import sttp.tapir.server.vertx.decoders.{VertxRequestBody, VertxServerRequest} -import sttp.tapir.server.vertx.encoders.{VertxOutputEncoders, VertxToResponseBody} -import sttp.tapir.server.vertx.interpreters.{CommonServerInterpreter, FromVFuture, RunAsync} -import sttp.tapir.server.vertx.routing.PathMapping.extractRouteDefinition -import sttp.tapir.server.vertx.zio.streams._ -import sttp.tapir.server.vertx.VertxBodyListener -import sttp.tapir.ztapir.{RIOMonadError, ZServerEndpoint} -import _root_.zio._ -import _root_.zio.blocking.Blocking - -import java.util.concurrent.atomic.AtomicReference - -trait VertxZioServerInterpreter[R <: Blocking] extends CommonServerInterpreter with VertxErrorHandler { - def vertxZioServerOptions: VertxZioServerOptions[R] = VertxZioServerOptions.default - - def route(e: ZServerEndpoint[R, ZioStreams with WebSockets])(implicit - runtime: Runtime[R] - ): Router => Route = { router => - mountWithDefaultHandlers(e)(router, extractRouteDefinition(e.endpoint), vertxZioServerOptions) - .handler(endpointHandler(e)) - } - - private def endpointHandler( - e: ZServerEndpoint[R, ZioStreams with WebSockets] - )(implicit runtime: Runtime[R]): Handler[RoutingContext] = { - val fromVFuture = new RioFromVFuture[R] - implicit val monadError: RIOMonadError[R] = new RIOMonadError[R] - implicit val bodyListener: BodyListener[RIO[R, *], RoutingContext => Future[Void]] = - new VertxBodyListener[RIO[R, *]](new ZioRunAsync(runtime)) - val zioReadStream = zioReadStreamCompatible(vertxZioServerOptions) - val interpreter = new ServerInterpreter[ZioStreams with WebSockets, RIO[R, *], RoutingContext => Future[Void], ZioStreams]( - _ => List(e), - new VertxRequestBody[RIO[R, *], ZioStreams](vertxZioServerOptions, fromVFuture)(zioReadStream), - new VertxToResponseBody(vertxZioServerOptions)(zioReadStream), - vertxZioServerOptions.interceptors, - vertxZioServerOptions.deleteFile - ) - - new Handler[RoutingContext] { - override def handle(rc: RoutingContext) = { - val serverRequest = VertxServerRequest(rc) - - val result: ZIO[R, Throwable, Any] = - interpreter(serverRequest) - .flatMap { - // in vertx, endpoints are attempted to be decoded individually; if this endpoint didn't match - another one might - case RequestResult.Failure(_) => ZIO.succeed(rc.next()) - case RequestResult.Response(response) => - Task.effectAsync((k: Task[Unit] => Unit) => { - VertxOutputEncoders(response) - .apply(rc) - .onComplete(d => { - if (d.succeeded()) k(Task.unit) else k(Task.fail(d.cause())) - }) - }) - } - .catchAll { t => handleError(rc, t).asRIO } - - // we obtain the cancel token only after the effect is run, so we need to pass it to the exception handler - // via a mutable ref; however, before this is done, it's possible an exception has already been reported; - // if so, we need to use this fact to cancel the operation nonetheless - val cancelRef = new AtomicReference[Option[Either[Throwable, Fiber.Id => Exit[Throwable, Any]]]](None) - - rc.response.exceptionHandler { (t: Throwable) => - cancelRef.getAndSet(Some(Left(t))).collect { case Right(c) => - rc.vertx() - .executeBlocking[Unit]( - (promise: Promise[Unit]) => { - c(Fiber.Id.None) - promise.complete(()) - }, - false - ) - } - () - } - - val canceler = runtime.unsafeRunAsyncCancelable(result.catchAll { t => handleError(rc, t).asRIO }) { - case Exit.Failure(_) => () // should be handled - case Exit.Success(_) => () - } - cancelRef.getAndSet(Some(Right(canceler))).collect { case Left(_) => - rc.vertx() - .executeBlocking[Unit]( - (promise: Promise[Unit]) => { - canceler(Fiber.Id.None) - promise.complete(()) - }, - false - ) - } - - () - } - } - } -} - -object VertxZioServerInterpreter { - def apply[R <: Blocking]( - serverOptions: VertxZioServerOptions[R] = VertxZioServerOptions.default - ): VertxZioServerInterpreter[R] = { - new VertxZioServerInterpreter[R] { - override def vertxZioServerOptions: VertxZioServerOptions[R] = serverOptions - } - } - - private[vertx] class RioFromVFuture[R] extends FromVFuture[RIO[R, *]] { - def apply[T](f: => Future[T]): RIO[R, T] = f.asRIO - } - - private[vertx] class ZioRunAsync[R](runtime: Runtime[R]) extends RunAsync[RIO[R, *]] { - override def apply[T](f: => RIO[R, T]): Unit = runtime.unsafeRunAsync(f)(_ => ()) - } - - implicit class VertxFutureToRIO[A](f: => Future[A]) { - def asRIO[R]: RIO[R, A] = { - RIO.effectAsync { cb => - f.onComplete { handler => - if (handler.succeeded()) { - cb(Task.succeed(handler.result())) - } else { - cb(Task.fail(handler.cause())) - } - } - } - } - } -} diff --git a/server/vertx-server/zio1/src/main/scala/sttp/tapir/server/vertx/zio/VertxZioServerOptions.scala b/server/vertx-server/zio1/src/main/scala/sttp/tapir/server/vertx/zio/VertxZioServerOptions.scala deleted file mode 100644 index 8dc8fd4349..0000000000 --- a/server/vertx-server/zio1/src/main/scala/sttp/tapir/server/vertx/zio/VertxZioServerOptions.scala +++ /dev/null @@ -1,58 +0,0 @@ -package sttp.tapir.server.vertx.zio - -import io.vertx.core.logging.{Logger, LoggerFactory} -import sttp.tapir.server.interceptor.log.{DefaultServerLog, ServerLog} -import sttp.tapir.server.interceptor.{CustomiseInterceptors, Interceptor} -import sttp.tapir.{Defaults, TapirFile} -import _root_.zio.{RIO, URIO} -import _root_.zio.blocking._ -import sttp.tapir.server.vertx.VertxServerOptions - -final case class VertxZioServerOptions[R]( - uploadDirectory: TapirFile, - deleteFile: TapirFile => RIO[R, Unit], - maxQueueSizeForReadStream: Int, - interceptors: List[Interceptor[RIO[R, *]]] -) extends VertxServerOptions[RIO[R, *]] { - def prependInterceptor(i: Interceptor[RIO[R, *]]): VertxZioServerOptions[R] = - copy(interceptors = i :: interceptors) - def appendInterceptor(i: Interceptor[RIO[R, *]]): VertxZioServerOptions[R] = - copy(interceptors = interceptors :+ i) - - def widen[R2 <: R]: VertxZioServerOptions[R2] = this.asInstanceOf[VertxZioServerOptions[R2]] -} - -object VertxZioServerOptions { - - /** Allows customising the interceptors used by the server interpreter. */ - def customiseInterceptors[R <: Blocking]: CustomiseInterceptors[RIO[R, *], VertxZioServerOptions[R]] = - CustomiseInterceptors( - createOptions = (ci: CustomiseInterceptors[RIO[R, *], VertxZioServerOptions[R]]) => - VertxZioServerOptions( - VertxServerOptions.uploadDirectory(), - file => effectBlocking(Defaults.deleteFile()(file)), - maxQueueSizeForReadStream = 16, - ci.interceptors - ) - ).serverLog(defaultServerLog[R](LoggerFactory.getLogger("tapir-vertx"))) - - implicit def default[R <: Blocking]: VertxZioServerOptions[R] = customiseInterceptors.options - - def defaultServerLog[R](log: Logger): DefaultServerLog[RIO[R, *]] = { - DefaultServerLog( - doLogWhenReceived = debugLog(log)(_, None), - doLogWhenHandled = debugLog(log), - doLogAllDecodeFailures = infoLog(log), - doLogExceptions = (msg: String, ex: Throwable) => URIO.succeed { log.error(msg, ex) }, - noLog = URIO.unit - ) - } - - private def debugLog[R](log: Logger)(msg: String, exOpt: Option[Throwable]): RIO[R, Unit] = URIO.succeed { - VertxServerOptions.debugLog(log)(msg, exOpt) - } - - private def infoLog[R](log: Logger)(msg: String, exOpt: Option[Throwable]): RIO[R, Unit] = URIO.succeed { - VertxServerOptions.infoLog(log)(msg, exOpt) - } -} diff --git a/server/vertx-server/zio1/src/main/scala/sttp/tapir/server/vertx/zio/streams/zio.scala b/server/vertx-server/zio1/src/main/scala/sttp/tapir/server/vertx/zio/streams/zio.scala deleted file mode 100644 index 87e41466d8..0000000000 --- a/server/vertx-server/zio1/src/main/scala/sttp/tapir/server/vertx/zio/streams/zio.scala +++ /dev/null @@ -1,242 +0,0 @@ -package sttp.tapir.server.vertx.zio - -import _root_.zio._ -import _root_.zio.stream.{Stream, ZStream} -import io.vertx.core.Handler -import io.vertx.core.buffer.Buffer -import io.vertx.core.streams.ReadStream -import sttp.capabilities.zio.ZioStreams -import sttp.tapir.model.WebSocketFrameDecodeFailure -import sttp.tapir.server.vertx.streams.ReadStreamState._ -import sttp.tapir.server.vertx.streams._ -import sttp.tapir.server.vertx.streams.websocket._ -import sttp.tapir.{DecodeResult, WebSocketBodyOutput} -import sttp.ws.WebSocketFrame -import zio.clock.Clock - -import scala.collection.immutable.{Queue => SQueue} -import scala.language.postfixOps - -package object streams { - - implicit class DeferredOps[A](dfd: Promise[Nothing, A]) extends DeferredLike[UIO, A] { - override def complete(a: A): UIO[Unit] = - dfd.done(Exit.Success(a)).unit - - override def get: UIO[A] = - dfd.await - } - - def zioReadStreamCompatible[R](opts: VertxZioServerOptions[R])(implicit - runtime: Runtime[Any] - ): ReadStreamCompatible[ZioStreams] = new ReadStreamCompatible[ZioStreams] { - - override val streams: ZioStreams = ZioStreams - - override def asReadStream(stream: Stream[Throwable, Byte]): ReadStream[Buffer] = - mapToReadStream[Chunk[Byte], Buffer]( - stream.mapChunks(chunk => Chunk.single(chunk)), - chunk => Buffer.buffer(chunk.toArray) - ) - - private def mapToReadStream[I, O](stream: Stream[Throwable, I], fn: I => O): ReadStream[O] = - runtime - .unsafeRunSync(for { - promise <- Promise.make[Nothing, Unit] - state <- Ref.make(StreamState.empty[UIO, O](promise)) - _ <- stream.foreach { chunk => - val buffer = fn(chunk) - state.get.flatMap { - case StreamState(None, handler, _, _) => - ZIO.effect(handler.handle(buffer)) - case StreamState(Some(promise), _, _, _) => - for { - _ <- promise.get - // Handler in state may be updated since the moment when we wait - // promise so let's get more recent version. - updatedState <- state.get - } yield updatedState.handler.handle(buffer) - } - } onExit { - case Exit.Success(()) => - state.get.flatMap { state => - ZIO - .effect(state.endHandler.handle(null)) - .catchAll(cause2 => ZIO.effect(state.errorHandler.handle(cause2)).either) - } - case Exit.Failure(cause) => - state.get.flatMap { state => - ZIO - .effect(state.errorHandler.handle(cause.squash)) - .catchAll(cause2 => ZIO.effect(state.errorHandler.handle(cause2)).either) - } - } forkDaemon - } yield new ReadStream[O] { self => - override def handler(handler: Handler[O]): ReadStream[O] = - runtime - .unsafeRunSync(state.update(_.copy(handler = handler))) - .toEither - .fold(throw _, _ => self) - - override def exceptionHandler(handler: Handler[Throwable]): ReadStream[O] = - runtime - .unsafeRunSync(state.update(_.copy(errorHandler = handler))) - .toEither - .fold(throw _, _ => self) - - override def endHandler(handler: Handler[Void]): ReadStream[O] = - runtime - .unsafeRunSync(state.update(_.copy(endHandler = handler))) - .toEither - .fold(throw _, _ => self) - - override def pause(): ReadStream[O] = - runtime - .unsafeRunSync(for { - promise <- Promise.make[Nothing, Unit] - _ <- state.update { - case cur @ StreamState(Some(_), _, _, _) => - cur - case cur @ StreamState(None, _, _, _) => - cur.copy(paused = Some(promise)) - } - } yield self) - .toEither - .fold(throw _, identity) - - override def resume(): ReadStream[O] = - runtime - .unsafeRunSync(for { - oldState <- state.getAndUpdate(_.copy(paused = None)) - _ <- oldState.paused.fold[UIO[Any]](UIO.unit)(_.complete(())) - } yield self) - .toEither - .fold(throw _, identity) - - override def fetch(x: Long): ReadStream[O] = - self - }) - .toEither - .fold(throw _, identity) - - override def fromReadStream(readStream: ReadStream[Buffer], maxBytes: Option[Long]): Stream[Throwable, Byte] = { - fromReadStreamInternal(readStream).mapConcatChunk(buffer => Chunk.fromArray(buffer.getBytes)) - } - - private def fromReadStreamInternal[T](readStream: ReadStream[T]): Stream[Throwable, T] = - runtime - .unsafeRunSync(for { - stateRef <- Ref.make(ReadStreamState[UIO, T](Queued(SQueue.empty), Queued(SQueue.empty))) - stream = ZStream.unfoldM(()) { _ => - for { - dfd <- Promise.make[Nothing, WrappedBuffer[T]] - tuple <- stateRef.modify(_.dequeueBuffer(dfd).swap) - (mbBuffer, mbAction) = tuple - _ <- ZIO.foreach(mbAction)(identity) - wrappedBuffer <- mbBuffer match { - case Left(deferred) => - deferred.get - case Right(buffer) => - UIO.succeed(buffer) - } - result <- wrappedBuffer match { - case Right(buffer) => UIO.some((buffer, ())) - case Left(None) => UIO.none - case Left(Some(cause)) => IO.fail(cause) - } - } yield result - } - _ <- ZStream - .unfoldM(())({ _ => - for { - dfd <- Promise.make[Nothing, WrappedEvent] - mbEvent <- stateRef.modify(_.dequeueActivationEvent(dfd).swap) - result <- mbEvent match { - case Left(deferred) => - deferred.get - case Right(event) => - UIO.succeed(event) - } - } yield result.map((_, ())) - }) - .mapM({ - case Pause => - IO.effect(readStream.pause()) - case Resume => - IO.effect(readStream.resume()) - }) - .runDrain - .forkDaemon - } yield { - readStream.endHandler { _ => - runtime - .unsafeRunSync(stateRef.modify(_.halt(None).swap).flatMap(ZIO.foreach_(_)(identity))) - .fold(c => throw c.squash, identity) - } - readStream.exceptionHandler { cause => - runtime - .unsafeRunSync(stateRef.modify(_.halt(Some(cause)).swap).flatMap(ZIO.foreach_(_)(identity))) - .fold(c => throw c.squash, identity) - } - readStream.handler { buffer => - val maxSize = opts.maxQueueSizeForReadStream - runtime - .unsafeRunSync(stateRef.modify(_.enqueue(buffer, maxSize).swap).flatMap(ZIO.foreach_(_)(identity))) - .fold(c => throw c.squash, identity) - } - - stream - }) - .toEither - .fold(throw _, identity) - - override def webSocketPipe[REQ, RESP]( - readStream: ReadStream[WebSocketFrame], - pipe: streams.Pipe[REQ, RESP], - o: WebSocketBodyOutput[streams.Pipe[REQ, RESP], REQ, RESP, _, ZioStreams] - ): ReadStream[WebSocketFrame] = { - val stream0 = fromReadStreamInternal(readStream) - val stream1 = optionallyContatenateFrames(stream0, o.concatenateFragmentedFrames) - val stream2 = optionallyIgnorePong(stream1, o.ignorePong) - val autoPings = o.autoPing match { - case Some((interval, frame)) => - ZStream.tick(duration.Duration.fromScala(interval)).as(frame).provideLayer(Clock.live) - case None => - ZStream.empty - } - - val stream3 = stream2 - .mapM { frame => - o.requests.decode(frame) match { - case failure: DecodeResult.Failure => - ZIO.fail(new WebSocketFrameDecodeFailure(frame, failure)) - case DecodeResult.Value(v) => - ZIO.succeed(v) - } - } - - val stream4 = pipe(stream3) - .map(o.responses.encode) - .mergeTerminateLeft(autoPings) - .concat(ZStream(WebSocketFrame.close)) - - mapToReadStream[WebSocketFrame, WebSocketFrame](stream4, identity) - } - - private def optionallyContatenateFrames( - s: Stream[Throwable, WebSocketFrame], - doConcatenate: Boolean - ): Stream[Throwable, WebSocketFrame] = - if (doConcatenate) { - s.mapAccum(None: Accumulator)(concatenateFrames).collect { case Some(f) => f } - } else { - s - } - - private def optionallyIgnorePong( - s: Stream[Throwable, WebSocketFrame], - ignore: Boolean - ): Stream[Throwable, WebSocketFrame] = - if (ignore) s.filterNot(isPong) else s - } -} diff --git a/server/vertx-server/zio1/src/test/scala/sttp/tapir/server/vertx/zio/VertxStubServerTest.scala b/server/vertx-server/zio1/src/test/scala/sttp/tapir/server/vertx/zio/VertxStubServerTest.scala deleted file mode 100644 index 64185017a3..0000000000 --- a/server/vertx-server/zio1/src/test/scala/sttp/tapir/server/vertx/zio/VertxStubServerTest.scala +++ /dev/null @@ -1,27 +0,0 @@ -package sttp.tapir.server.vertx.zio - -import sttp.capabilities.zio.ZioStreams -import sttp.client3.testing.SttpBackendStub -import sttp.tapir.server.interceptor.CustomiseInterceptors -import sttp.tapir.server.tests.{CreateServerStubTest, ServerStubStreamingTest, ServerStubTest} -import _root_.zio.stream.ZStream -import _root_.zio.{RIO, Runtime} -import _root_.zio.blocking.Blocking -import sttp.tapir.ztapir.RIOMonadError - -import scala.concurrent.Future - -object VertxZioCreateServerStubTest extends CreateServerStubTest[RIO[Blocking, *], VertxZioServerOptions[Blocking]] { - override def customiseInterceptors: CustomiseInterceptors[RIO[Blocking, *], VertxZioServerOptions[Blocking]] = - VertxZioServerOptions.customiseInterceptors - override def stub[R]: SttpBackendStub[RIO[Blocking, *], R] = SttpBackendStub(new RIOMonadError[Blocking]) - override def asFuture[A]: RIO[Blocking, A] => Future[A] = task => Runtime.default.unsafeRunToFuture(task) -} - -class VertxZioServerStubTest extends ServerStubTest(VertxZioCreateServerStubTest) - -class VertxZioServerStubStreamingTest extends ServerStubStreamingTest(VertxZioCreateServerStubTest, ZioStreams) { - - /** Must be an instance of streams.BinaryStream */ - override def sampleStream: Any = ZStream.fromIterable(List("hello")) -} diff --git a/server/vertx-server/zio1/src/test/scala/sttp/tapir/server/vertx/zio/ZioVertxServerTest.scala b/server/vertx-server/zio1/src/test/scala/sttp/tapir/server/vertx/zio/ZioVertxServerTest.scala deleted file mode 100644 index dbe692745e..0000000000 --- a/server/vertx-server/zio1/src/test/scala/sttp/tapir/server/vertx/zio/ZioVertxServerTest.scala +++ /dev/null @@ -1,38 +0,0 @@ -package sttp.tapir.server.vertx.zio - -import cats.effect.{IO, Resource} -import io.vertx.core.Vertx -import sttp.capabilities.zio.ZioStreams -import sttp.monad.MonadError -import sttp.tapir.server.tests._ -import sttp.tapir.tests.{Test, TestSuite} -import _root_.zio.RIO -import _root_.zio.blocking.Blocking -import sttp.tapir.ztapir.RIOMonadError -import zio.stream.ZStream -import zio.Task - -class ZioVertxServerTest extends TestSuite { - def vertxResource: Resource[IO, Vertx] = - Resource.make(IO.delay(Vertx.vertx()))(vertx => IO.delay(vertx.close()).void) - - override def tests: Resource[IO, List[Test]] = backendResource.flatMap { backend => - vertxResource.map { implicit vertx => - implicit val m: MonadError[RIO[Blocking, *]] = new RIOMonadError[Blocking] - val interpreter = new ZioVertxTestServerInterpreter(vertx) - val createServerTest = new DefaultCreateServerTest(backend, interpreter) - - new AllServerTests(createServerTest, interpreter, backend, multipart = false, reject = false, options = false).tests() ++ - new ServerMultipartTests( - createServerTest, - partContentTypeHeaderSupport = false, // README: doesn't seem supported but I may be wrong - partOtherHeaderSupport = false - ).tests() ++ - new ServerStreamingTests(createServerTest, maxLengthSupported = false).tests(ZioStreams)(_ => Task.unit) ++ - new ServerWebSocketTests(createServerTest, ZioStreams) { - override def functionToPipe[A, B](f: A => B): streams.Pipe[A, B] = in => in.map(f) - override def emptyPipe[A, B]: streams.Pipe[A, B] = _ => ZStream.empty - }.tests() - } - } -} diff --git a/server/vertx-server/zio1/src/test/scala/sttp/tapir/server/vertx/zio/ZioVertxTestServerInterpreter.scala b/server/vertx-server/zio1/src/test/scala/sttp/tapir/server/vertx/zio/ZioVertxTestServerInterpreter.scala deleted file mode 100644 index 3c59e6eae0..0000000000 --- a/server/vertx-server/zio1/src/test/scala/sttp/tapir/server/vertx/zio/ZioVertxTestServerInterpreter.scala +++ /dev/null @@ -1,48 +0,0 @@ -package sttp.tapir.server.vertx.zio - -import cats.data.NonEmptyList -import cats.effect.{IO, Resource} -import io.vertx.core.Vertx -import io.vertx.core.http.HttpServerOptions -import io.vertx.ext.web.{Route, Router} -import sttp.capabilities.zio.ZioStreams -import sttp.capabilities.WebSockets -import sttp.tapir.server.ServerEndpoint -import sttp.tapir.server.tests.TestServerInterpreter -import sttp.tapir.tests._ -import _root_.zio.{Has, RIO, Runtime} -import _root_.zio.blocking.Blocking -import sttp.tapir.server.vertx.VertxTestServerInterpreter -import scala.concurrent.duration.FiniteDuration - -class ZioVertxTestServerInterpreter(vertx: Vertx) - extends TestServerInterpreter[RIO[Blocking, *], ZioStreams with WebSockets, VertxZioServerOptions[Blocking], Router => Route] { - import ZioVertxTestServerInterpreter._ - - override def route( - es: List[ServerEndpoint[ZioStreams with WebSockets, RIO[Blocking, *]]], - interceptors: Interceptors - ): Router => Route = { router => - val options: VertxZioServerOptions[Blocking] = interceptors(VertxZioServerOptions.customiseInterceptors).options - val interpreter = VertxZioServerInterpreter(options) - es.map(interpreter.route(_)(runtime)(router)).last - } - - override def serverWithStop( - routes: NonEmptyList[Router => Route], - gracefulShutdownTimeout: Option[FiniteDuration] - ): Resource[IO, (Port, KillSwitch)] = { - val router = Router.router(vertx) - val server = vertx.createHttpServer(new HttpServerOptions().setPort(0)).requestHandler(router) - routes.toList.foreach(_.apply(router)) - val listenIO = VertxTestServerInterpreter.vertxFutureToIo(server.listen(0)) - // Vertx doesn't offer graceful shutdown with timeout OOTB - Resource.make(listenIO.map(s => (s.actualPort(), VertxTestServerInterpreter.vertxFutureToIo(s.close).void))) { case (_, release) => - release - } - } -} - -object ZioVertxTestServerInterpreter { - implicit val runtime: Runtime[Blocking] = Runtime.default.as(Has(Blocking.Service.live)) -} diff --git a/server/vertx-server/zio1/src/test/scala/sttp/tapir/server/vertx/zio/streams/ZStreamTest.scala b/server/vertx-server/zio1/src/test/scala/sttp/tapir/server/vertx/zio/streams/ZStreamTest.scala deleted file mode 100644 index fafa0daffc..0000000000 --- a/server/vertx-server/zio1/src/test/scala/sttp/tapir/server/vertx/zio/streams/ZStreamTest.scala +++ /dev/null @@ -1,225 +0,0 @@ -package sttp.tapir.server.vertx.zio.streams - -import java.nio.ByteBuffer -import io.vertx.core.buffer.Buffer -import org.scalatest.flatspec.AsyncFlatSpec -import org.scalatest.matchers.should.Matchers -import _root_.zio._ -import _root_.zio.blocking.Blocking -import _root_.zio.duration._ -import _root_.zio.stream.ZStream -import _root_.zio.clock.Clock -import sttp.capabilities.zio.ZioStreams -import _root_.zio.{Runtime => ZIORuntime} -import sttp.tapir.server.vertx.zio.VertxZioServerOptions -import sttp.tapir.server.vertx.streams.FakeStream - -class ZStreamTest extends AsyncFlatSpec with Matchers { - - private val runtime = ZIORuntime.default - - private val options = VertxZioServerOptions.default[Blocking].copy(maxQueueSizeForReadStream = 4) - - def intAsBuffer(int: Int): Chunk[Byte] = { - val buffer = ByteBuffer.allocate(4) - buffer.putInt(int) - buffer.flip() - Chunk.fromByteBuffer(buffer) - } - - def intAsVertxBuffer(int: Int): Buffer = - Buffer.buffer(intAsBuffer(int).toArray) - - def bufferAsInt(buffer: Buffer): Int = { - val bs = buffer.getBytes() - (bs(0) & 0xff) << 24 | (bs(1) & 0xff) << 16 | (bs(2) & 0xff) << 8 | (bs(3) & 0xff) - } - - def chunkAsInt(chunk: Chunk[Byte]): Int = - bufferAsInt(Buffer.buffer(chunk.toArray)) - - def shouldIncreaseMonotonously(xs: List[Int]): Unit = { - all(xs.iterator.sliding(2).map(_.toList).toList) should matchPattern { - case (first: Int) :: (second: Int) :: Nil if first + 1 == second => - } - () - } - - val schedule = (Schedule.spaced(50.millis) >>> Schedule.elapsed).whileOutput(_ < 15.seconds) - - def eventually[A](task: Task[A])(cond: PartialFunction[A, Unit]): Task[A] = - task.tap(a => ZIO.effect(cond(a))).retry(schedule).provideLayer(Clock.live) - - "ZioReadStreamCompatible" should "convert zio stream to read stream" in { - val stream: ZioStreams.BinaryStream = ZStream - .tick(100.millis) - .mapAccum(0)((acc, _) => (acc + 1, acc)) - .haltAfter(3.seconds) - .map(intAsBuffer) - .flattenChunks - .provideLayer(Clock.live) - val readStream = zioReadStreamCompatible(options)(runtime).asReadStream(stream) - runtime - .unsafeRunToFuture(for { - ref <- Ref.make[List[Int]](Nil) - completed <- Ref.make[Boolean](false) - _ <- Task.effect { - readStream.handler { buffer => - runtime.unsafeRunSync(ref.update(_ :+ bufferAsInt(buffer))) - () - } - } - _ <- Task.effect { - readStream.endHandler { _ => - runtime.unsafeRunSync(completed.set(true)) - () - } - } - _ <- Task.effect(readStream.resume()) - _ <- eventually(ref.get)({ case _ :: _ => () }) - _ <- Task.effect(readStream.pause()) - _ <- ZIO.sleep(1.seconds) - snapshot2 <- ref.get - _ <- Task.effect(readStream.resume()) - snapshot3 <- eventually(ref.get)({ case list => list.length should be > snapshot2.length }) - _ = shouldIncreaseMonotonously(snapshot3) - _ <- eventually(completed.get)({ case true => () }) - } yield succeed) - } - - it should "interrupt read stream after zio stream interruption" in { - val stream = ZStream - .tick(100.millis) - .mapAccum(0)((acc, _) => (acc + 1, acc)) - .haltAfter(7.seconds) - .map(intAsBuffer) - .flattenChunks - .provideLayer(Clock.live) ++ ZStream.fail(new Exception("!")) - val readStream = zioReadStreamCompatible(options)(runtime).asReadStream(stream) - runtime - .unsafeRunToFuture(for { - ref <- Ref.make[List[Int]](Nil) - completedRef <- Ref.make[Boolean](false) - interruptedRef <- Ref.make[Option[Throwable]](None) - _ <- Task.effect { - readStream.handler { buffer => - runtime.unsafeRunSync(ref.update(_ :+ bufferAsInt(buffer))) - () - } - } - _ <- Task.effect { - readStream.endHandler { _ => - runtime.unsafeRunSync(completedRef.set(true)) - () - } - } - _ <- Task.effect { - readStream.exceptionHandler { cause => - runtime.unsafeRunSync(interruptedRef.set(Some(cause))) - () - } - } - _ <- Task.effect(readStream.resume()) - snapshot <- eventually(ref.get)({ case list => list.length should be > 3 }) - _ = shouldIncreaseMonotonously(snapshot) - _ <- eventually(completedRef.get zip interruptedRef.get)({ case (false, Some(_)) => - }) - } yield succeed) - } - - it should "drain read stream without pauses if buffer has enough space" in { - val opts = options.copy(maxQueueSizeForReadStream = 128) - val count = 100 - val readStream = new FakeStream() - val stream = zioReadStreamCompatible(opts)(runtime).fromReadStream(readStream, maxBytes = None) - runtime - .unsafeRunToFuture(for { - resultFiber <- stream - .mapChunks((chunkAsInt _).andThen(Chunk.single)) - .toIterator - .map(_.toList) - .useNow - .fork - _ <- ZIO.effect { - (1 to count).foreach { i => - readStream.handle(intAsVertxBuffer(i)) - } - readStream.end() - } - result <- resultFiber.join - } yield { - val successes = result.collect { case Right(i) => i } - shouldIncreaseMonotonously(successes) - successes should have size count.toLong - readStream.pauseCount shouldBe 0 - // readStream.resumeCount shouldBe 0 - }) - } - - it should "drain read stream with small buffer" in { - val opts = options.copy(maxQueueSizeForReadStream = 4) - val count = 100 - val readStream = new FakeStream() - val stream = zioReadStreamCompatible(opts)(runtime).fromReadStream(readStream, maxBytes = None) - runtime - .unsafeRunToFuture(for { - resultFiber <- stream - .mapChunks((chunkAsInt _).andThen(Chunk.single)) - .mapM(i => ZIO.sleep(50.millis).as(i)) - .toIterator - .map(_.toList) - .useNow - .fork - _ <- ZIO - .effect({ - (1 to count).foreach { i => - Thread.sleep(25) - readStream.handle(intAsVertxBuffer(i)) - } - readStream.end() - }) - .fork - result <- resultFiber.join - } yield { - val successes = result.collect { case Right(i) => i } - shouldIncreaseMonotonously(successes) - successes should have size count.toLong - readStream.pauseCount should be > 0 - readStream.resumeCount should be > 0 - }) - } - - it should "drain failed read stream" in { - val opts = options.copy(maxQueueSizeForReadStream = 4) - val count = 50 - val readStream = new FakeStream() - val stream = zioReadStreamCompatible(opts)(runtime).fromReadStream(readStream, maxBytes = None) - runtime - .unsafeRunToFuture(for { - resultFiber <- stream - .mapChunks((chunkAsInt _).andThen(Chunk.single)) - .mapM(i => ZIO.sleep(50.millis).as(i)) - .toIterator - .map(_.toList) - .useNow - .fork - _ <- ZIO - .effect({ - (1 to count).foreach { i => - Thread.sleep(25) - readStream.handle(intAsVertxBuffer(i)) - } - readStream.fail(new Exception("!")) - }) - .fork - result <- resultFiber.join - } yield { - val successes = result.collect { case Right(i) => i } - shouldIncreaseMonotonously(successes) - successes should have size count.toLong - readStream.pauseCount should be > 0 - readStream.resumeCount should be > 0 - result.lastOption.collect { case Left(e) => e } should not be empty - }) - } -} diff --git a/server/zio-http-server/src/test/scala/sttp/tapir/server/ziohttp/ZioHttpServerTest.scala b/server/zio-http-server/src/test/scala/sttp/tapir/server/ziohttp/ZioHttpServerTest.scala index 09a4554bad..d42c139fcc 100644 --- a/server/zio-http-server/src/test/scala/sttp/tapir/server/ziohttp/ZioHttpServerTest.scala +++ b/server/zio-http-server/src/test/scala/sttp/tapir/server/ziohttp/ZioHttpServerTest.scala @@ -251,7 +251,8 @@ class ZioHttpServerTest extends TestSuite { interpreter, multipleValueHeaderSupport = false, supportsMultipleSetCookieHeaders = false, - invulnerableToUnsanitizedHeaders = false + invulnerableToUnsanitizedHeaders = false, + maxContentLength = false ).tests() ++ // TODO: re-enable static content once a newer zio http is available. Currently these tests often fail with: // Cause: java.io.IOException: parsing HTTP/1.1 status line, receiving [f2 content], parser state [STATUS_LINE] @@ -265,7 +266,7 @@ class ZioHttpServerTest extends TestSuite { file = false, options = false ).tests() ++ - new ServerStreamingTests(createServerTest, maxLengthSupported = true).tests(ZioStreams)(drainZStream) ++ + new ServerStreamingTests(createServerTest).tests(ZioStreams)(drainZStream) ++ new ZioHttpCompositionTest(createServerTest).tests() ++ new ServerWebSocketTests(createServerTest, ZioStreams) { override def functionToPipe[A, B](f: A => B): ZioStreams.Pipe[A, B] = in => in.map(f) diff --git a/serverless/aws/lambda-cats-effect-tests/src/test/scala/sttp/tapir/serverless/aws/lambda/tests/AwsLambdaStubHttpTest.scala b/serverless/aws/lambda-cats-effect-tests/src/test/scala/sttp/tapir/serverless/aws/lambda/tests/AwsLambdaStubHttpTest.scala index 8d5f159420..a36bf71a0c 100644 --- a/serverless/aws/lambda-cats-effect-tests/src/test/scala/sttp/tapir/serverless/aws/lambda/tests/AwsLambdaStubHttpTest.scala +++ b/serverless/aws/lambda-cats-effect-tests/src/test/scala/sttp/tapir/serverless/aws/lambda/tests/AwsLambdaStubHttpTest.scala @@ -14,7 +14,7 @@ class AwsLambdaStubHttpTest extends TestSuite { override def tests: Resource[IO, List[Test]] = Resource.eval( IO.pure { val createTestServer = new AwsLambdaCreateServerStubTest - new ServerBasicTests(createTestServer, AwsLambdaStubHttpTest.testServerInterpreter)(catsMonadIO).tests() ++ + new ServerBasicTests(createTestServer, AwsLambdaStubHttpTest.testServerInterpreter, maxContentLength = false)(catsMonadIO).tests() ++ new ServerMetricsTest(createTestServer).tests() } ) diff --git a/serverless/aws/lambda-zio-tests/src/test/scala/sttp/tapir/serverless/aws/ziolambda/tests/AwsLambdaStubHttpTest.scala b/serverless/aws/lambda-zio-tests/src/test/scala/sttp/tapir/serverless/aws/ziolambda/tests/AwsLambdaStubHttpTest.scala index e69b76cecb..df78db8f10 100644 --- a/serverless/aws/lambda-zio-tests/src/test/scala/sttp/tapir/serverless/aws/ziolambda/tests/AwsLambdaStubHttpTest.scala +++ b/serverless/aws/lambda-zio-tests/src/test/scala/sttp/tapir/serverless/aws/ziolambda/tests/AwsLambdaStubHttpTest.scala @@ -18,7 +18,7 @@ class AwsLambdaStubHttpTest extends TestSuite { import AwsLambdaStubHttpTest.m val createTestServer = new AwsLambdaCreateServerStubTest - new ServerBasicTests(createTestServer, AwsLambdaStubHttpTest.testServerInterpreter).tests() ++ + new ServerBasicTests(createTestServer, AwsLambdaStubHttpTest.testServerInterpreter, maxContentLength = false).tests() ++ new ServerMetricsTest(createTestServer).tests() } ) diff --git a/tests/src/main/scala/sttp/tapir/tests/data/FruitAmount.scala b/tests/src/main/scala/sttp/tapir/tests/data/FruitAmount.scala index 467681fd23..4556345e29 100644 --- a/tests/src/main/scala/sttp/tapir/tests/data/FruitAmount.scala +++ b/tests/src/main/scala/sttp/tapir/tests/data/FruitAmount.scala @@ -2,6 +2,8 @@ package sttp.tapir.tests.data case class FruitAmount(fruit: String, amount: Int) +case class DoubleFruit(fruitA: String, fruitB: String) + case class IntWrapper(v: Int) extends AnyVal case class StringWrapper(v: String) extends AnyVal