diff --git a/build.sbt b/build.sbt index c17c70412d..255ab4f138 100644 --- a/build.sbt +++ b/build.sbt @@ -1353,7 +1353,10 @@ lazy val nettyServer: ProjectMatrix = (projectMatrix in file("server/netty-serve .settings(commonJvmSettings) .settings( name := "tapir-netty-server", - libraryDependencies ++= Seq("io.netty" % "netty-all" % Versions.nettyAll) + libraryDependencies ++= Seq( + "io.netty" % "netty-all" % Versions.nettyAll, + "com.typesafe.netty" % "netty-reactive-streams-http" % Versions.nettyReactiveStreams + ) ++ loggerDependencies, // needed because of https://github.com/coursier/coursier/issues/2016 useCoursier := false @@ -1362,7 +1365,10 @@ lazy val nettyServer: ProjectMatrix = (projectMatrix in file("server/netty-serve .dependsOn(serverCore, serverTests % Test) lazy val nettyServerCats: ProjectMatrix = nettyServerProject("cats", catsEffect) - .settings(libraryDependencies += "com.softwaremill.sttp.shared" %% "fs2" % Versions.sttpShared) + .settings(libraryDependencies ++= Seq( + "com.softwaremill.sttp.shared" %% "fs2" % Versions.sttpShared, + "co.fs2" %% "fs2-reactive-streams" % Versions.fs2 + )) lazy val nettyServerZio: ProjectMatrix = nettyServerProject("zio", zio) .settings(libraryDependencies += "dev.zio" %% "zio-interop-cats" % Versions.zioInteropCats) diff --git a/doc/server/netty.md b/doc/server/netty.md index f253265f87..4c714cd145 100644 --- a/doc/server/netty.md +++ b/doc/server/netty.md @@ -15,11 +15,11 @@ To expose an endpoint using a [Netty](https://netty.io)-based server, first add Then, use: -* `NettyFutureServer().addEndpoints` to expose `Future`-based server endpoints. -* `NettyCatsServer().addEndpoints` to expose `F`-based server endpoints, where `F` is any cats-effect supported effect. -* `NettyZioServer().addEndpoints` to expose `ZIO`-based server endpoints, where `R` represents ZIO requirements supported effect. +- `NettyFutureServer().addEndpoints` to expose `Future`-based server endpoints. +- `NettyCatsServer().addEndpoints` to expose `F`-based server endpoints, where `F` is any cats-effect supported effect. [Streaming](../endpoint/streaming.md) request and response bodies is supported with fs2. +- `NettyZioServer().addEndpoints` to expose `ZIO`-based server endpoints, where `R` represents ZIO requirements supported effect. -These methods require a single, or a list of `ServerEndpoint`s, which can be created by adding [server logic](logic.md) +These methods require a single, or a list of `ServerEndpoint`s, which can be created by adding [server logic](logic.md) to an endpoint. For example: @@ -36,7 +36,7 @@ val helloWorld = endpoint .out(stringBody) .serverLogic(name => Future.successful[Either[Unit, String]](Right(s"Hello, $name!"))) -val binding: Future[NettyFutureServerBinding] = +val binding: Future[NettyFutureServerBinding] = NettyFutureServer().addEndpoint(helloWorld).start() ``` @@ -60,7 +60,7 @@ NettyFutureServer().port(9090).addEndpoints(???) NettyFutureServer(NettyFutureServerOptions.customiseInterceptors.serverLog(None).options) // customise Netty config -NettyFutureServer(NettyConfig.default.socketBacklog(256)) +NettyFutureServer(NettyConfig.defaultNoStreaming.socketBacklog(256)) ``` ## Domain socket support @@ -83,4 +83,4 @@ val serverBinding: Future[NettyFutureDomainSocketBinding] = Future.successful[Either[Unit, String]](Right(s"Hello, $name!"))) ) .startUsingDomainSocket(Paths.get(System.getProperty("java.io.tmpdir"), "hello")) -``` \ No newline at end of file +``` diff --git a/examples/src/main/scala/sttp/tapir/examples/streaming/StreamingNettyFs2Server.scala b/examples/src/main/scala/sttp/tapir/examples/streaming/StreamingNettyFs2Server.scala new file mode 100644 index 0000000000..a7eb446d6a --- /dev/null +++ b/examples/src/main/scala/sttp/tapir/examples/streaming/StreamingNettyFs2Server.scala @@ -0,0 +1,73 @@ +package sttp.tapir.examples.streaming + +import cats.effect.{ExitCode, IO, IOApp} +import cats.implicits._ +import fs2.{Chunk, Stream} +import sttp.capabilities.fs2.Fs2Streams +import sttp.client3._ +import sttp.model.HeaderNames +import sttp.tapir._ +import sttp.tapir.server.ServerEndpoint +import sttp.tapir.server.netty.cats.{NettyCatsServer, NettyCatsServerBinding} + +import java.nio.charset.StandardCharsets +import scala.concurrent.duration._ + +object StreamingNettyFs2Server extends IOApp { + // corresponds to: GET /receive?name=... + // We need to provide both the schema of the value (for documentation), as well as the format (media type) of the + // body. Here, the schema is a `string` (set by `streamTextBody`) and the media type is `text/plain`. + val streamingEndpoint: PublicEndpoint[Unit, Unit, (Long, Stream[IO, Byte]), Fs2Streams[IO]] = + endpoint.get + .in("receive") + .out(header[Long](HeaderNames.ContentLength)) + .out(streamTextBody(Fs2Streams[IO])(CodecFormat.TextPlain(), Some(StandardCharsets.UTF_8))) + + val serverEndpoint: ServerEndpoint[Fs2Streams[IO], IO] = streamingEndpoint + .serverLogicSuccess { _ => + val size = 100L + Stream + .emit(List[Char]('a', 'b', 'c', 'd')) + .repeat + .flatMap(list => Stream.chunk(Chunk.seq(list))) + .metered[IO](100.millis) + .take(size) + .covary[IO] + .map(_.toByte) + .pure[IO] + .map(s => (size, s)) + } + + private val declaredPort = 9090 + private val declaredHost = "localhost" + + override def run(args: List[String]): IO[ExitCode] = { + // starting the server + NettyCatsServer + .io() + .use { server => + + val startServer: IO[NettyCatsServerBinding[IO]] = server + .port(declaredPort) + .host(declaredHost) + .addEndpoint(serverEndpoint) + .start() + + startServer + .map { binding => + + val port = binding.port + val host = binding.hostName + println(s"Server started at port = ${binding.port}") + + val backend: SttpBackend[Identity, Any] = HttpURLConnectionBackend() + val result: String = + basicRequest.response(asStringAlways).get(uri"http://$declaredHost:$declaredPort/receive").send(backend).body + println("Got result: " + result) + + assert(result == "abcd" * 25) + } + .as(ExitCode.Success) + } + } +} diff --git a/project/Versions.scala b/project/Versions.scala index 967e43e267..409d05b6a7 100644 --- a/project/Versions.scala +++ b/project/Versions.scala @@ -19,6 +19,7 @@ object Versions { val finatra = "22.12.0" val catbird = "21.12.0" val json4s = "4.0.6" + val nettyReactiveStreams = "2.0.8" val sprayJson = "1.3.6" val scalaCheck = "1.17.0" val scalaTest = "3.2.16" diff --git a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServer.scala b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServer.scala index 39f0ba1826..604a13e57d 100644 --- a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServer.scala +++ b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServer.scala @@ -15,15 +15,16 @@ import sttp.tapir.server.netty.{NettyConfig, Route} import java.net.{InetSocketAddress, SocketAddress} import java.nio.file.{Path, Paths} import java.util.UUID +import sttp.capabilities.fs2.Fs2Streams case class NettyCatsServer[F[_]: Async](routes: Vector[Route[F]], options: NettyCatsServerOptions[F], config: NettyConfig) { - def addEndpoint(se: ServerEndpoint[Any, F]): NettyCatsServer[F] = addEndpoints(List(se)) - def addEndpoint(se: ServerEndpoint[Any, F], overrideOptions: NettyCatsServerOptions[F]): NettyCatsServer[F] = + def addEndpoint(se: ServerEndpoint[Fs2Streams[F], F]): NettyCatsServer[F] = addEndpoints(List(se)) + def addEndpoint(se: ServerEndpoint[Fs2Streams[F], F], overrideOptions: NettyCatsServerOptions[F]): NettyCatsServer[F] = addEndpoints(List(se), overrideOptions) - def addEndpoints(ses: List[ServerEndpoint[Any, F]]): NettyCatsServer[F] = addRoute( + def addEndpoints(ses: List[ServerEndpoint[Fs2Streams[F], F]]): NettyCatsServer[F] = addRoute( NettyCatsServerInterpreter(options).toRoute(ses) ) - def addEndpoints(ses: List[ServerEndpoint[Any, F]], overrideOptions: NettyCatsServerOptions[F]): NettyCatsServer[F] = addRoute( + def addEndpoints(ses: List[ServerEndpoint[Fs2Streams[F], F]], overrideOptions: NettyCatsServerOptions[F]): NettyCatsServer[F] = addRoute( NettyCatsServerInterpreter(overrideOptions).toRoute(ses) ) @@ -60,7 +61,7 @@ case class NettyCatsServer[F[_]: Async](routes: Vector[Route[F]], options: Netty val channelFuture = NettyBootstrap( config, - new NettyServerHandler(route, (f: () => F[Unit]) => options.dispatcher.unsafeToFuture(f())), + new NettyServerHandler(route, (f: () => F[Unit]) => options.dispatcher.unsafeToFuture(f()), config.maxContentLength), eventLoopGroup, socketOverride ) @@ -81,9 +82,9 @@ case class NettyCatsServer[F[_]: Async](routes: Vector[Route[F]], options: Netty object NettyCatsServer { def apply[F[_]: Async](dispatcher: Dispatcher[F]): NettyCatsServer[F] = - NettyCatsServer(Vector.empty, NettyCatsServerOptions.default(dispatcher), NettyConfig.default) + NettyCatsServer(Vector.empty, NettyCatsServerOptions.default(dispatcher), NettyConfig.defaultWithStreaming) def apply[F[_]: Async](options: NettyCatsServerOptions[F]): NettyCatsServer[F] = - NettyCatsServer(Vector.empty, options, NettyConfig.default) + NettyCatsServer(Vector.empty, options, NettyConfig.defaultWithStreaming) def apply[F[_]: Async](dispatcher: Dispatcher[F], config: NettyConfig): NettyCatsServer[F] = NettyCatsServer(Vector.empty, NettyCatsServerOptions.default(dispatcher), config) def apply[F[_]: Async](options: NettyCatsServerOptions[F], config: NettyConfig): NettyCatsServer[F] = diff --git a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServerInterpreter.scala b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServerInterpreter.scala index 3e2d3d9072..6cffa531ce 100644 --- a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServerInterpreter.scala +++ b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServerInterpreter.scala @@ -2,28 +2,50 @@ package sttp.tapir.server.netty.cats import cats.effect.Async import cats.effect.std.Dispatcher +import sttp.capabilities.fs2.Fs2Streams import sttp.monad.MonadError +import sttp.monad.syntax._ import sttp.tapir.integ.cats.effect.CatsMonadError import sttp.tapir.server.ServerEndpoint -import sttp.tapir.server.netty.Route -import sttp.tapir.server.netty.internal.{NettyServerInterpreter, RunAsync} +import sttp.tapir.server.interceptor.RequestResult +import sttp.tapir.server.interceptor.reject.RejectInterceptor +import sttp.tapir.server.interpreter.{BodyListener, FilterServerEndpoints, ServerInterpreter} +import sttp.tapir.server.netty.internal.{NettyBodyListener, RunAsync, _} +import sttp.tapir.server.netty.{NettyResponse, NettyServerRequest, Route} trait NettyCatsServerInterpreter[F[_]] { implicit def async: Async[F] def nettyServerOptions: NettyCatsServerOptions[F] - def toRoute(ses: List[ServerEndpoint[Any, F]]): Route[F] = { + def toRoute(ses: List[ServerEndpoint[Fs2Streams[F], F]]): Route[F] = { + implicit val monad: MonadError[F] = new CatsMonadError[F] val runAsync = new RunAsync[F] { override def apply[T](f: => F[T]): Unit = nettyServerOptions.dispatcher.unsafeRunAndForget(f) } - NettyServerInterpreter.toRoute( - ses, - nettyServerOptions.interceptors, - nettyServerOptions.createFile, - nettyServerOptions.deleteFile, - runAsync + implicit val bodyListener: BodyListener[F, NettyResponse] = new NettyBodyListener(runAsync) + + val interceptors = nettyServerOptions.interceptors + val createFile = nettyServerOptions.createFile + val deleteFile = nettyServerOptions.deleteFile + + val serverInterpreter = new ServerInterpreter[Fs2Streams[F], F, NettyResponse, Fs2Streams[F]]( + FilterServerEndpoints(ses), + new NettyCatsRequestBody(createFile), + new NettyCatsToResponseBody(nettyServerOptions.dispatcher, delegate = new NettyToResponseBody), + RejectInterceptor.disableWhenSingleEndpoint(interceptors, ses), + deleteFile ) + + val handler: Route[F] = { (request: NettyServerRequest) => + serverInterpreter(request) + .map { + case RequestResult.Response(response) => Some(response) + case RequestResult.Failure(_) => None + } + } + + handler } } diff --git a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/internal/NettyCatsRequestBody.scala b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/internal/NettyCatsRequestBody.scala new file mode 100644 index 0000000000..59ab71939d --- /dev/null +++ b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/internal/NettyCatsRequestBody.scala @@ -0,0 +1,64 @@ +package sttp.tapir.server.netty.internal + +import cats.effect.{Async, Sync} +import cats.syntax.all._ +import com.typesafe.netty.http.StreamedHttpRequest +import fs2.Chunk +import fs2.interop.reactivestreams.StreamSubscriber +import fs2.io.file.{Files, Path} +import io.netty.buffer.ByteBufUtil +import io.netty.handler.codec.http.{FullHttpRequest, HttpContent} +import sttp.capabilities.fs2.Fs2Streams +import sttp.tapir.model.ServerRequest +import sttp.tapir.server.interpreter.{RawValue, RequestBody} +import sttp.tapir.{FileRange, InputStreamRange, RawBodyType, TapirFile} + +import java.io.ByteArrayInputStream +import java.nio.ByteBuffer + +private[netty] class NettyCatsRequestBody[F[_]](createFile: ServerRequest => F[TapirFile])(implicit val monad: Async[F]) + extends RequestBody[F, Fs2Streams[F]] { + + override val streams: Fs2Streams[F] = Fs2Streams[F] + + override def toRaw[R](serverRequest: ServerRequest, bodyType: RawBodyType[R]): F[RawValue[R]] = { + + bodyType match { + case RawBodyType.StringBody(charset) => nettyRequestBytes(serverRequest).map(bs => RawValue(new String(bs, charset))) + case RawBodyType.ByteArrayBody => + nettyRequestBytes(serverRequest).map(RawValue(_)) + case RawBodyType.ByteBufferBody => + nettyRequestBytes(serverRequest).map(bs => RawValue(ByteBuffer.wrap(bs))) + case RawBodyType.InputStreamBody => + nettyRequestBytes(serverRequest).map(bs => RawValue(new ByteArrayInputStream(bs))) + case RawBodyType.InputStreamRangeBody => + nettyRequestBytes(serverRequest).map(bs => RawValue(InputStreamRange(() => new ByteArrayInputStream(bs)))) + case RawBodyType.FileBody => + createFile(serverRequest) + .flatMap(tapirFile => { + toStream(serverRequest) + .through( + Files[F](Files.forAsync[F]).writeAll(Path.fromNioPath(tapirFile.toPath)) + ) + .compile + .drain + .map(_ => RawValue(FileRange(tapirFile), Seq(FileRange(tapirFile)))) + }) + case _: RawBodyType.MultipartBody => ??? + } + } + + override def toStream(serverRequest: ServerRequest): streams.BinaryStream = { + val nettyRequest = serverRequest.underlying.asInstanceOf[StreamedHttpRequest] + fs2.Stream + .eval(StreamSubscriber[F, HttpContent](NettyRequestBody.DefaultChunkSize)) + .flatMap(s => s.sub.stream(Sync[F].delay(nettyRequest.subscribe(s)))) + .flatMap(httpContent => fs2.Stream.chunk(Chunk.byteBuffer(httpContent.content.nioBuffer()))) + } + + private def nettyRequestBytes(serverRequest: ServerRequest): F[Array[Byte]] = serverRequest.underlying match { + case req: FullHttpRequest => monad.delay(ByteBufUtil.getBytes(req.content())) + case _: StreamedHttpRequest => toStream(serverRequest).compile.to(Chunk).map(_.toArray[Byte]) + case other => monad.raiseError(new UnsupportedOperationException(s"Unexpected Netty request of type ${other.getClass().getName()}")) + } +} diff --git a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/internal/NettyCatsToResponseBody.scala b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/internal/NettyCatsToResponseBody.scala new file mode 100644 index 0000000000..bcf754da4e --- /dev/null +++ b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/internal/NettyCatsToResponseBody.scala @@ -0,0 +1,91 @@ +package sttp.tapir.server.netty.internal + +import cats.effect.kernel.{Async, Sync} +import cats.effect.std.Dispatcher +import fs2.Chunk +import fs2.interop.reactivestreams._ +import fs2.io.file.{Files, Flags, Path} +import io.netty.buffer.Unpooled +import io.netty.channel.ChannelHandlerContext +import io.netty.handler.codec.http.{DefaultHttpContent, HttpContent} +import org.reactivestreams.Publisher +import sttp.capabilities.fs2.Fs2Streams +import sttp.model.HasHeaders +import sttp.tapir.server.interpreter.ToResponseBody +import sttp.tapir.server.netty.NettyResponse +import sttp.tapir.server.netty.NettyResponseContent._ +import sttp.tapir.{CodecFormat, RawBodyType, WebSocketBodyOutput} + +import java.io.InputStream +import java.nio.charset.Charset + +class NettyCatsToResponseBody[F[_]: Async](dispatcher: Dispatcher[F], delegate: NettyToResponseBody) + extends ToResponseBody[NettyResponse, Fs2Streams[F]] { + override val streams: Fs2Streams[F] = Fs2Streams[F] + + override def fromRawValue[R](v: R, headers: HasHeaders, format: CodecFormat, bodyType: RawBodyType[R]): NettyResponse = { + bodyType match { + + case RawBodyType.InputStreamBody => + val stream = inputStreamToFs2(() => v) + (ctx: ChannelHandlerContext) => + new ReactivePublisherNettyResponseContent(ctx.newPromise(), fs2StreamToPublisher(stream)) + + case RawBodyType.InputStreamRangeBody => + val stream = v.range + .map(range => inputStreamToFs2(v.inputStreamFromRangeStart).take(range.contentLength)) + .getOrElse(inputStreamToFs2(v.inputStream)) + (ctx: ChannelHandlerContext) => + new ReactivePublisherNettyResponseContent(ctx.newPromise(), fs2StreamToPublisher(stream)) + + case RawBodyType.FileBody => + val tapirFile = v + val path = Path.fromNioPath(tapirFile.file.toPath) + val stream = tapirFile.range + .flatMap(r => r.startAndEnd.map(s => Files[F](Files.forAsync[F]).readRange(path, r.contentLength.toInt, s._1, s._2))) + .getOrElse(Files[F](Files.forAsync[F]).readAll(path, NettyToResponseBody.DefaultChunkSize, Flags.Read)) + + (ctx: ChannelHandlerContext) => + new ReactivePublisherNettyResponseContent(ctx.newPromise(), fs2StreamToPublisher(stream)) + + case _: RawBodyType.MultipartBody => throw new UnsupportedOperationException + + case _ => delegate.fromRawValue(v, headers, format, bodyType) + } + } + + private def inputStreamToFs2(inputStream: () => InputStream) = + fs2.io.readInputStream( + Sync[F].blocking(inputStream()), + NettyToResponseBody.DefaultChunkSize + ) + + private def fs2StreamToPublisher(stream: streams.BinaryStream): Publisher[HttpContent] = { + // Deprecated constructor, but the proposed one does roughly the same, forcing a dedicated + // dispatcher, which results in a Resource[], which is hard to afford here + StreamUnicastPublisher( + stream.chunkLimit(NettyToResponseBody.DefaultChunkSize) + .map { chunk => + val bytes: Chunk.ArraySlice[Byte] = chunk.compact + + new DefaultHttpContent(Unpooled.wrappedBuffer(bytes.values, bytes.offset, bytes.length)) + }, + dispatcher + ) + } + + override def fromStreamValue( + v: streams.BinaryStream, + headers: HasHeaders, + format: CodecFormat, + charset: Option[Charset] + ): NettyResponse = + (ctx: ChannelHandlerContext) => { + new ReactivePublisherNettyResponseContent(ctx.newPromise(), fs2StreamToPublisher(v)) + } + + override def fromWebSocketPipe[REQ, RESP]( + pipe: streams.Pipe[REQ, RESP], + o: WebSocketBodyOutput[streams.Pipe[REQ, RESP], REQ, RESP, _, Fs2Streams[F]] + ): NettyResponse = throw new UnsupportedOperationException +} diff --git a/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsServerTest.scala b/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsServerTest.scala index 648fd8a69a..1ec8c40b14 100644 --- a/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsServerTest.scala +++ b/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsServerTest.scala @@ -3,6 +3,7 @@ package sttp.tapir.server.netty.cats import cats.effect.{IO, Resource} import io.netty.channel.nio.NioEventLoopGroup import org.scalatest.EitherValues +import sttp.capabilities.fs2.Fs2Streams import sttp.monad.MonadError import sttp.tapir.integ.cats.effect.CatsMonadError import sttp.tapir.server.netty.internal.FutureUtil @@ -10,6 +11,7 @@ import sttp.tapir.server.tests._ import sttp.tapir.tests.{Test, TestSuite} class NettyCatsServerTest extends TestSuite with EitherValues { + override def tests: Resource[IO, List[Test]] = backendResource.flatMap { backend => Resource @@ -20,7 +22,14 @@ class NettyCatsServerTest extends TestSuite with EitherValues { val interpreter = new NettyCatsTestServerInterpreter(eventLoopGroup, dispatcher) val createServerTest = new DefaultCreateServerTest(backend, interpreter) - val tests = new AllServerTests(createServerTest, interpreter, backend, multipart = false).tests() + val tests = new AllServerTests( + createServerTest, + interpreter, + backend, + multipart = false, + maxContentLength = Some(NettyCatsTestServerInterpreter.maxContentLength) + ) + .tests() ++ new ServerStreamingTests(createServerTest, Fs2Streams[IO]).tests() IO.pure((tests, eventLoopGroup)) } { case (_, eventLoopGroup) => diff --git a/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsTestServerInterpreter.scala b/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsTestServerInterpreter.scala index f41d7dcbdb..b0847aedb4 100644 --- a/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsTestServerInterpreter.scala +++ b/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsTestServerInterpreter.scala @@ -8,10 +8,11 @@ import sttp.tapir.server.ServerEndpoint import sttp.tapir.server.netty.{NettyConfig, Route} import sttp.tapir.server.tests.TestServerInterpreter import sttp.tapir.tests.Port +import sttp.capabilities.fs2.Fs2Streams class NettyCatsTestServerInterpreter(eventLoopGroup: NioEventLoopGroup, dispatcher: Dispatcher[IO]) - extends TestServerInterpreter[IO, Any, NettyCatsServerOptions[IO], Route[IO]] { - override def route(es: List[ServerEndpoint[Any, IO]], interceptors: Interceptors): Route[IO] = { + extends TestServerInterpreter[IO, Fs2Streams[IO], NettyCatsServerOptions[IO], Route[IO]] { + override def route(es: List[ServerEndpoint[Fs2Streams[IO], IO]], interceptors: Interceptors): Route[IO] = { val serverOptions: NettyCatsServerOptions[IO] = interceptors( NettyCatsServerOptions.customiseInterceptors[IO](dispatcher) ).options @@ -19,7 +20,11 @@ class NettyCatsTestServerInterpreter(eventLoopGroup: NioEventLoopGroup, dispatch } override def server(routes: NonEmptyList[Route[IO]]): Resource[IO, Port] = { - val config = NettyConfig.default.eventLoopGroup(eventLoopGroup).randomPort.withDontShutdownEventLoopGroupOnClose + val config = NettyConfig.defaultWithStreaming + .eventLoopGroup(eventLoopGroup) + .randomPort + .withDontShutdownEventLoopGroupOnClose + .maxContentLength(NettyCatsTestServerInterpreter.maxContentLength) val options = NettyCatsServerOptions.default[IO](dispatcher) val bind: IO[NettyCatsServerBinding[IO]] = NettyCatsServer(options, config).addRoutes(routes.toList).start() @@ -28,3 +33,7 @@ class NettyCatsTestServerInterpreter(eventLoopGroup: NioEventLoopGroup, dispatch .map(_.port) } } + +object NettyCatsTestServerInterpreter { + val maxContentLength = 10000 +} diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyConfig.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyConfig.scala index b65e072d0b..3d5c6af29d 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyConfig.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyConfig.scala @@ -1,5 +1,6 @@ package sttp.tapir.server.netty +import com.typesafe.netty.http.HttpStreamsServerHandler import io.netty.channel.epoll.{Epoll, EpollEventLoopGroup, EpollServerSocketChannel} import io.netty.channel.kqueue.{KQueue, KQueueEventLoopGroup, KQueueServerSocketChannel} import io.netty.channel.nio.NioEventLoopGroup @@ -9,7 +10,6 @@ import io.netty.handler.codec.http.{HttpObjectAggregator, HttpServerCodec} import io.netty.handler.logging.LoggingHandler import io.netty.handler.ssl.SslContext import io.netty.handler.stream.ChunkedWriteHandler -import io.netty.handler.timeout.ReadTimeoutException import sttp.tapir.server.netty.NettyConfig.EventLoopConfig import scala.concurrent.duration._ @@ -48,7 +48,7 @@ case class NettyConfig( host: String, port: Int, shutdownEventLoopGroupOnClose: Boolean, - maxContentLength: Int, + maxContentLength: Option[Int], socketBacklog: Int, requestTimeout: Option[FiniteDuration], connectionTimeout: Option[FiniteDuration], @@ -69,8 +69,8 @@ case class NettyConfig( def withShutdownEventLoopGroupOnClose: NettyConfig = copy(shutdownEventLoopGroupOnClose = true) def withDontShutdownEventLoopGroupOnClose: NettyConfig = copy(shutdownEventLoopGroupOnClose = false) - def maxContentLength(m: Int): NettyConfig = copy(maxContentLength = m) - def noMaxContentLength: NettyConfig = copy(maxContentLength = Integer.MAX_VALUE) + def maxContentLength(m: Int): NettyConfig = copy(maxContentLength = Some(m)) + def noMaxContentLength: NettyConfig = copy(maxContentLength = None) def socketBacklog(s: Int): NettyConfig = copy(socketBacklog = s) @@ -98,7 +98,7 @@ case class NettyConfig( } object NettyConfig { - def default: NettyConfig = NettyConfig( + def defaultNoStreaming: NettyConfig = NettyConfig( host = "localhost", port = 8080, shutdownEventLoopGroupOnClose = true, @@ -108,24 +108,34 @@ object NettyConfig { connectionTimeout = Some(10.seconds), socketTimeout = Some(60.seconds), lingerTimeout = Some(60.seconds), - maxContentLength = Integer.MAX_VALUE, + maxContentLength = None, addLoggingHandler = false, sslContext = None, eventLoopConfig = EventLoopConfig.auto, socketConfig = NettySocketConfig.default, - initPipeline = cfg => defaultInitPipeline(cfg)(_, _) + initPipeline = cfg => defaultInitPipelineNoStreaming(cfg)(_, _) ) - def defaultInitPipeline(cfg: NettyConfig)(pipeline: ChannelPipeline, handler: ChannelHandler): Unit = { + def defaultInitPipelineNoStreaming(cfg: NettyConfig)(pipeline: ChannelPipeline, handler: ChannelHandler): Unit = { cfg.sslContext.foreach(s => pipeline.addLast(s.newHandler(pipeline.channel().alloc()))) pipeline.addLast(new HttpServerCodec()) - pipeline.addLast(new HttpObjectAggregator(cfg.maxContentLength)) + pipeline.addLast(new HttpObjectAggregator(cfg.maxContentLength.getOrElse(Integer.MAX_VALUE))) pipeline.addLast(new ChunkedWriteHandler()) pipeline.addLast(handler) + () + } + + def defaultInitPipelineStreaming(cfg: NettyConfig)(pipeline: ChannelPipeline, handler: ChannelHandler): Unit = { + cfg.sslContext.foreach(s => pipeline.addLast(s.newHandler(pipeline.channel().alloc()))) + pipeline.addLast(new HttpServerCodec()) + pipeline.addLast(new HttpStreamsServerHandler()) + pipeline.addLast(handler) if (cfg.addLoggingHandler) pipeline.addLast(new LoggingHandler()) () } + def defaultWithStreaming: NettyConfig = defaultNoStreaming.copy(initPipeline = cfg => defaultInitPipelineStreaming(cfg)(_, _)) + case class EventLoopConfig(initEventLoopGroup: () => EventLoopGroup, serverChannel: Class[_ <: ServerChannel]) object EventLoopConfig { diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyFutureServer.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyFutureServer.scala index faa18cb069..b53483a0a9 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyFutureServer.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyFutureServer.scala @@ -55,7 +55,12 @@ case class NettyFutureServer(routes: Vector[FutureRoute], options: NettyFutureSe val route = Route.combine(routes) val channelFuture = - NettyBootstrap(config, new NettyServerHandler(route, (f: () => Future[Unit]) => f()), eventLoopGroup, socketOverride) + NettyBootstrap( + config, + new NettyServerHandler(route, (f: () => Future[Unit]) => f(), config.maxContentLength), + eventLoopGroup, + socketOverride + ) nettyChannelFutureToScala(channelFuture).map(ch => (ch.localAddress().asInstanceOf[SA], () => stop(ch, eventLoopGroup))) } @@ -71,10 +76,10 @@ case class NettyFutureServer(routes: Vector[FutureRoute], options: NettyFutureSe object NettyFutureServer { def apply()(implicit ec: ExecutionContext): NettyFutureServer = - NettyFutureServer(Vector.empty, NettyFutureServerOptions.default, NettyConfig.default) + NettyFutureServer(Vector.empty, NettyFutureServerOptions.default, NettyConfig.defaultNoStreaming) def apply(serverOptions: NettyFutureServerOptions)(implicit ec: ExecutionContext): NettyFutureServer = - NettyFutureServer(Vector.empty, serverOptions, NettyConfig.default) + NettyFutureServer(Vector.empty, serverOptions, NettyConfig.defaultNoStreaming) def apply(config: NettyConfig)(implicit ec: ExecutionContext): NettyFutureServer = NettyFutureServer(Vector.empty, NettyFutureServerOptions.default, config) diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyResponseContent.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyResponseContent.scala index dc79aebae2..d3adfe3480 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyResponseContent.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyResponseContent.scala @@ -2,7 +2,9 @@ package sttp.tapir.server.netty import io.netty.buffer.ByteBuf import io.netty.channel.ChannelPromise +import io.netty.handler.codec.http.HttpContent import io.netty.handler.stream.{ChunkedFile, ChunkedStream} +import org.reactivestreams.Publisher sealed trait NettyResponseContent { def channelPromise: ChannelPromise @@ -13,4 +15,5 @@ object NettyResponseContent { final case class ChunkedStreamNettyResponseContent(channelPromise: ChannelPromise, chunkedStream: ChunkedStream) extends NettyResponseContent final case class ChunkedFileNettyResponseContent(channelPromise: ChannelPromise, chunkedFile: ChunkedFile) extends NettyResponseContent + final case class ReactivePublisherNettyResponseContent(channelPromise: ChannelPromise, publisher: Publisher[HttpContent]) extends NettyResponseContent } diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyServerRequest.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyServerRequest.scala index bda70cd592..7b9be99466 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyServerRequest.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyServerRequest.scala @@ -2,13 +2,14 @@ package sttp.tapir.server.netty import scala.collection.JavaConverters._ import scala.collection.immutable.Seq -import io.netty.handler.codec.http.{FullHttpRequest, QueryStringDecoder} +import io.netty.handler.codec.http.{HttpRequest, QueryStringDecoder} import sttp.model.{Header, Method, QueryParams, Uri} import sttp.tapir.{AttributeKey, AttributeMap} import sttp.tapir.model.{ConnectionInfo, ServerRequest} import sttp.tapir.server.netty.internal.RichNettyHttpHeaders +import io.netty.handler.codec.http.FullHttpRequest -case class NettyServerRequest(req: FullHttpRequest, attributes: AttributeMap = AttributeMap.Empty) extends ServerRequest { +case class NettyServerRequest(req: HttpRequest, attributes: AttributeMap = AttributeMap.Empty) extends ServerRequest { override lazy val protocol: String = req.protocolVersion().text() override lazy val connectionInfo: ConnectionInfo = ConnectionInfo.NoInfo override lazy val underlying: Any = req @@ -25,9 +26,12 @@ case class NettyServerRequest(req: FullHttpRequest, attributes: AttributeMap = A override lazy val method: Method = Method.unsafeApply(req.method().name()) override lazy val uri: Uri = Uri.unsafeParse(req.uri()) override lazy val pathSegments: List[String] = uri.pathSegments.segments.map(_.v).filter(_.nonEmpty).toList - override lazy val headers: Seq[Header] = req.headers().toHeaderSeq ::: req.trailingHeaders().toHeaderSeq + override lazy val headers: Seq[Header] = req.headers().toHeaderSeq ::: (req match { + case full: FullHttpRequest => full.trailingHeaders().toHeaderSeq + case _ => List.empty + }) override def attribute[T](k: AttributeKey[T]): Option[T] = attributes.get(k) override def attribute[T](k: AttributeKey[T], v: T): NettyServerRequest = copy(attributes = attributes.put(k, v)) override def withUnderlying(underlying: Any): ServerRequest = - NettyServerRequest(req = underlying.asInstanceOf[FullHttpRequest], attributes) + NettyServerRequest(req = underlying.asInstanceOf[HttpRequest], attributes) } diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyRequestBody.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyRequestBody.scala index e32c06fc46..6c9d251676 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyRequestBody.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyRequestBody.scala @@ -9,7 +9,6 @@ import sttp.tapir.model.ServerRequest import sttp.monad.syntax._ import sttp.tapir.capabilities.NoStreams import sttp.tapir.server.interpreter.{RawValue, RequestBody} -import sttp.tapir.server.netty.NettyServerRequest import java.nio.ByteBuffer import java.nio.file.Files @@ -46,3 +45,7 @@ class NettyRequestBody[F[_]](createFile: ServerRequest => F[TapirFile])(implicit private def nettyRequest(serverRequest: ServerRequest): FullHttpRequest = serverRequest.underlying.asInstanceOf[FullHttpRequest] } + +private[internal] object NettyRequestBody { + val DefaultChunkSize = 8192 +} diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala index 7842f6b3c6..645521965f 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala @@ -1,75 +1,110 @@ package sttp.tapir.server.netty.internal +import com.typesafe.netty.http.{DefaultStreamedHttpResponse, StreamedHttpRequest} import com.typesafe.scalalogging.Logger import io.netty.buffer.{ByteBuf, Unpooled} import io.netty.channel._ +import io.netty.handler.codec.http.HttpHeaderNames.{CONNECTION, CONTENT_LENGTH} import io.netty.handler.codec.http._ import io.netty.handler.stream.{ChunkedFile, ChunkedStream} +import org.reactivestreams.Publisher import sttp.monad.MonadError import sttp.monad.syntax._ import sttp.tapir.server.model.ServerResponse -import sttp.tapir.server.netty.{NettyResponse, NettyServerRequest, Route} import sttp.tapir.server.netty.NettyResponseContent.{ ByteBufNettyResponseContent, ChunkedFileNettyResponseContent, - ChunkedStreamNettyResponseContent + ChunkedStreamNettyResponseContent, + ReactivePublisherNettyResponseContent } +import sttp.tapir.server.netty.{NettyResponse, NettyServerRequest, Route} import scala.collection.JavaConverters._ -class NettyServerHandler[F[_]](route: Route[F], unsafeRunAsync: (() => F[Unit]) => Unit)(implicit me: MonadError[F]) - extends SimpleChannelInboundHandler[FullHttpRequest] { +class NettyServerHandler[F[_]](route: Route[F], unsafeRunAsync: (() => F[Unit]) => Unit, maxContentLength: Option[Int])(implicit + me: MonadError[F] +) extends SimpleChannelInboundHandler[HttpRequest] { private val logger = Logger[NettyServerHandler[F]] - override def channelRead0(ctx: ChannelHandlerContext, request: FullHttpRequest): Unit = { + private val EntityTooLarge: FullHttpResponse = { + val res = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.REQUEST_ENTITY_TOO_LARGE, Unpooled.EMPTY_BUFFER) + res.headers().set(CONTENT_LENGTH, 0) + res + } + + private val EntityTooLargeClose: FullHttpResponse = { + val res = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.REQUEST_ENTITY_TOO_LARGE, Unpooled.EMPTY_BUFFER) + res.headers().set(CONTENT_LENGTH, 0) + res.headers().set(CONNECTION, HttpHeaderValues.CLOSE) + res + } + + override def channelRead0(ctx: ChannelHandlerContext, request: HttpRequest): Unit = { + + def runRoute(req: HttpRequest) = { + + route(NettyServerRequest(req)) + .map { + case Some(response) => response + case None => ServerResponse.notFound + } + .flatMap((serverResponse: ServerResponse[NettyResponse]) => + // in ZIO, exceptions thrown in .map become defects - instead, we want them represented as errors so that + // we get the 500 response, instead of dropping the request + try handleResponse(ctx, req, serverResponse).unit + catch { + case e: Exception => me.error[Unit](e) + } + ) + .handleError { case ex: Exception => + logger.error("Error while processing the request", ex) + // send 500 + val res = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.INTERNAL_SERVER_ERROR) + res.headers().set(HttpHeaderNames.CONTENT_LENGTH, 0) + res.handleCloseAndKeepAliveHeaders(req) + + ctx.writeAndFlush(res).closeIfNeeded(req) + me.unit(()) + } + } + if (HttpUtil.is100ContinueExpected(request)) { ctx.writeAndFlush(new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.CONTINUE)) () } else { - val req = request.retain() - - unsafeRunAsync { () => - route(NettyServerRequest(req)) - .map { - case Some(response) => response - case None => ServerResponse.notFound + request match { + case full: FullHttpRequest => + val req = full.retain() + unsafeRunAsync { () => + runRoute(req) + .ensure(me.eval(req.release())) + } // exceptions should be handled + case req: StreamedHttpRequest => + unsafeRunAsync { () => + runRoute(req) } - .flatMap((serverResponse: ServerResponse[NettyResponse]) => - // in ZIO, exceptions thrown in .map become defects - instead, we want them represented as errors so that - // we get the 500 response, instead of dropping the request - try handleResponse(ctx, req, serverResponse).unit - catch { - case e: Exception => me.error[Unit](e) - } - ) - .handleError { case ex: Exception => - logger.error("Error while processing the request", ex) - // send 500 - val res = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.INTERNAL_SERVER_ERROR) - res.headers().set(HttpHeaderNames.CONTENT_LENGTH, 0) - res.handleCloseAndKeepAliveHeaders(req) - - ctx.writeAndFlush(res).closeIfNeeded(req) - me.unit(()) - } - .ensure(me.eval(req.release())) - } // exceptions should be handled + case _ => throw new UnsupportedOperationException(s"Unexpected Netty request type: ${request.getClass.getName}") + } + () } } - private def handleResponse(ctx: ChannelHandlerContext, req: FullHttpRequest, serverResponse: ServerResponse[NettyResponse]): Unit = + private def handleResponse(ctx: ChannelHandlerContext, req: HttpRequest, serverResponse: ServerResponse[NettyResponse]): Unit = serverResponse.handle( ctx = ctx, byteBufHandler = (channelPromise, byteBuf) => { - val res = new DefaultFullHttpResponse(req.protocolVersion(), HttpResponseStatus.valueOf(serverResponse.code.code), byteBuf) - - res.setHeadersFrom(serverResponse) - res.handleContentLengthAndChunkedHeaders(Option(byteBuf.readableBytes())) - res.handleCloseAndKeepAliveHeaders(req) - ctx.writeAndFlush(res, channelPromise).closeIfNeeded(req) + if (maxContentLength.exists(_ < byteBuf.readableBytes)) + writeEntityTooLargeResponse(ctx, req) + else { + val res = new DefaultFullHttpResponse(req.protocolVersion(), HttpResponseStatus.valueOf(serverResponse.code.code), byteBuf) + res.setHeadersFrom(serverResponse) + res.handleContentLengthAndChunkedHeaders(Option(byteBuf.readableBytes())) + res.handleCloseAndKeepAliveHeaders(req) + ctx.writeAndFlush(res, channelPromise).closeIfNeeded(req) + } }, chunkedStreamHandler = (channelPromise, chunkedStream) => { val resHeader: DefaultHttpResponse = @@ -89,11 +124,19 @@ class NettyServerHandler[F[_]](route: Route[F], unsafeRunAsync: (() => F[Unit]) resHeader.setHeadersFrom(serverResponse) resHeader.handleContentLengthAndChunkedHeaders(Option(chunkedFile.length())) resHeader.handleCloseAndKeepAliveHeaders(req) - ctx.write(resHeader) // HttpChunkedInput will write the end marker (LastHttpContent) for us. ctx.writeAndFlush(new HttpChunkedInput(chunkedFile), channelPromise).closeIfNeeded(req) }, + reactiveStreamHandler = (channelPromise, publisher) => { + val res: DefaultStreamedHttpResponse = + new DefaultStreamedHttpResponse(req.protocolVersion(), HttpResponseStatus.valueOf(serverResponse.code.code), publisher) + + res.setHeadersFrom(serverResponse) + res.handleCloseAndKeepAliveHeaders(req) + ctx.writeAndFlush(res, channelPromise).closeIfNeeded(req) + + }, noBodyHandler = () => { val res = new DefaultFullHttpResponse( req.protocolVersion(), @@ -109,12 +152,39 @@ class NettyServerHandler[F[_]](route: Route[F], unsafeRunAsync: (() => F[Unit]) } ) + private def writeEntityTooLargeResponse(ctx: ChannelHandlerContext, req: HttpRequest): Unit = { + + if (!HttpUtil.is100ContinueExpected(req) && !HttpUtil.isKeepAlive(req)) { + val future: ChannelFuture = ctx.writeAndFlush(EntityTooLargeClose.retainedDuplicate()) + future.addListener(new ChannelFutureListener() { + override def operationComplete(future: ChannelFuture) = { + if (!future.isSuccess()) { + logger.warn("Failed to send a 413 Request Entity Too Large.", future.cause()) + } + ctx.close() + } + }) + } else { + ctx + .writeAndFlush(EntityTooLarge.retainedDuplicate()) + .addListener(new ChannelFutureListener() { + override def operationComplete(future: ChannelFuture) = { + if (!future.isSuccess()) { + logger.warn("Failed to send a 413 Request Entity Too Large.", future.cause()) + ctx.close() + } + } + }) + } + } + private implicit class RichServerNettyResponse(val r: ServerResponse[NettyResponse]) { def handle( ctx: ChannelHandlerContext, byteBufHandler: (ChannelPromise, ByteBuf) => Unit, chunkedStreamHandler: (ChannelPromise, ChunkedStream) => Unit, chunkedFileHandler: (ChannelPromise, ChunkedFile) => Unit, + reactiveStreamHandler: (ChannelPromise, Publisher[HttpContent]) => Unit, noBodyHandler: () => Unit ): Unit = { r.body match { @@ -122,9 +192,10 @@ class NettyServerHandler[F[_]](route: Route[F], unsafeRunAsync: (() => F[Unit]) val values = function(ctx) values match { - case r: ByteBufNettyResponseContent => byteBufHandler(r.channelPromise, r.byteBuf) - case r: ChunkedStreamNettyResponseContent => chunkedStreamHandler(r.channelPromise, r.chunkedStream) - case r: ChunkedFileNettyResponseContent => chunkedFileHandler(r.channelPromise, r.chunkedFile) + case r: ByteBufNettyResponseContent => byteBufHandler(r.channelPromise, r.byteBuf) + case r: ChunkedStreamNettyResponseContent => chunkedStreamHandler(r.channelPromise, r.chunkedStream) + case r: ChunkedFileNettyResponseContent => chunkedFileHandler(r.channelPromise, r.chunkedFile) + case r: ReactivePublisherNettyResponseContent => reactiveStreamHandler(r.channelPromise, r.publisher) } } case None => noBodyHandler() @@ -149,7 +220,7 @@ class NettyServerHandler[F[_]](route: Route[F], unsafeRunAsync: (() => F[Unit]) if (lengthUnknownAndChunkedShouldBeUsed) { m.headers().add(HttpHeaderNames.TRANSFER_ENCODING, HttpHeaderValues.CHUNKED) } } - def handleCloseAndKeepAliveHeaders(request: FullHttpRequest): Unit = { + def handleCloseAndKeepAliveHeaders(request: HttpRequest): Unit = { if (!HttpUtil.isKeepAlive(request)) m.headers.set(HttpHeaderNames.CONNECTION, HttpHeaderValues.CLOSE) else if (request.protocolVersion.equals(HttpVersion.HTTP_1_0)) @@ -158,7 +229,7 @@ class NettyServerHandler[F[_]](route: Route[F], unsafeRunAsync: (() => F[Unit]) } private implicit class RichChannelFuture(val cf: ChannelFuture) { - def closeIfNeeded(request: FullHttpRequest): Unit = { + def closeIfNeeded(request: HttpRequest): Unit = { if (!HttpUtil.isKeepAlive(request)) { cf.addListener(ChannelFutureListener.CLOSE) } diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyToResponseBody.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyToResponseBody.scala index 8a060c56cc..dd39ebef31 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyToResponseBody.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyToResponseBody.scala @@ -95,8 +95,8 @@ class NettyToResponseBody extends ToResponseBody[NettyResponse, NoStreams] { ): NettyResponse = throw new UnsupportedOperationException } -object NettyToResponseBody { - private val DefaultChunkSize = 8192 - private val IncludingLastOffset = 1 - private val ReadOnlyAccessMode = "r" +private[internal] object NettyToResponseBody { + val DefaultChunkSize = 8192 + val IncludingLastOffset = 1 + val ReadOnlyAccessMode = "r" } diff --git a/server/netty-server/src/test/scala/sttp/tapir/server/netty/NettyFutureTestServerInterpreter.scala b/server/netty-server/src/test/scala/sttp/tapir/server/netty/NettyFutureTestServerInterpreter.scala index 15b9b9ab30..37f2f2481a 100644 --- a/server/netty-server/src/test/scala/sttp/tapir/server/netty/NettyFutureTestServerInterpreter.scala +++ b/server/netty-server/src/test/scala/sttp/tapir/server/netty/NettyFutureTestServerInterpreter.scala @@ -18,7 +18,7 @@ class NettyFutureTestServerInterpreter(eventLoopGroup: NioEventLoopGroup)(implic } override def server(routes: NonEmptyList[FutureRoute]): Resource[IO, Port] = { - val config = NettyConfig.default.eventLoopGroup(eventLoopGroup).randomPort.withDontShutdownEventLoopGroupOnClose + val config = NettyConfig.defaultNoStreaming.eventLoopGroup(eventLoopGroup).randomPort.withDontShutdownEventLoopGroupOnClose val options = NettyFutureServerOptions.default val bind = IO.fromFuture(IO.delay(NettyFutureServer(options, config).addRoutes(routes.toList).start())) diff --git a/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/NettyZioServer.scala b/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/NettyZioServer.scala index 91e139e056..cfaa64cf81 100644 --- a/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/NettyZioServer.scala +++ b/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/NettyZioServer.scala @@ -66,7 +66,8 @@ case class NettyZioServer[R](routes: Vector[RIO[R, Route[RIO[R, *]]]], options: config, new NettyServerHandler[RIO[R, *]]( route, - (f: () => RIO[R, Unit]) => Unsafe.unsafe(implicit u => runtime.unsafe.runToFuture(f())) + (f: () => RIO[R, Unit]) => Unsafe.unsafe(implicit u => runtime.unsafe.runToFuture(f())), + config.maxContentLength ), eventLoopGroup, socketOverride @@ -92,8 +93,8 @@ case class NettyZioServer[R](routes: Vector[RIO[R, Route[RIO[R, *]]]], options: } object NettyZioServer { - def apply[R](): NettyZioServer[R] = NettyZioServer(Vector.empty, NettyZioServerOptions.default[R], NettyConfig.default) - def apply[R](options: NettyZioServerOptions[R]): NettyZioServer[R] = NettyZioServer(Vector.empty, options, NettyConfig.default) + def apply[R](): NettyZioServer[R] = NettyZioServer(Vector.empty, NettyZioServerOptions.default[R], NettyConfig.defaultNoStreaming) + def apply[R](options: NettyZioServerOptions[R]): NettyZioServer[R] = NettyZioServer(Vector.empty, options, NettyConfig.defaultNoStreaming) def apply[R](config: NettyConfig): NettyZioServer[R] = NettyZioServer(Vector.empty, NettyZioServerOptions.default[R], config) def apply[R](options: NettyZioServerOptions[R], config: NettyConfig): NettyZioServer[R] = NettyZioServer(Vector.empty, options, config) } diff --git a/server/netty-server/zio/src/test/scala/sttp/tapir/server/netty/zio/NettyZioTestServerInterpreter.scala b/server/netty-server/zio/src/test/scala/sttp/tapir/server/netty/zio/NettyZioTestServerInterpreter.scala index 629bec2ca8..1c66df45a4 100644 --- a/server/netty-server/zio/src/test/scala/sttp/tapir/server/netty/zio/NettyZioTestServerInterpreter.scala +++ b/server/netty-server/zio/src/test/scala/sttp/tapir/server/netty/zio/NettyZioTestServerInterpreter.scala @@ -21,7 +21,7 @@ class NettyZioTestServerInterpreter[R](eventLoopGroup: NioEventLoopGroup) } override def server(routes: NonEmptyList[Task[Route[Task]]]): Resource[IO, Port] = { - val config = NettyConfig.default.eventLoopGroup(eventLoopGroup).randomPort.withDontShutdownEventLoopGroupOnClose + val config = NettyConfig.defaultNoStreaming.eventLoopGroup(eventLoopGroup).randomPort.withDontShutdownEventLoopGroupOnClose val options = NettyZioServerOptions.default[R] val runtime: Runtime[R] = Runtime.default.asInstanceOf[Runtime[R]] diff --git a/server/tests/src/main/scala/sttp/tapir/server/tests/AllServerTests.scala b/server/tests/src/main/scala/sttp/tapir/server/tests/AllServerTests.scala index 96eea0ee72..369f0cc9b6 100644 --- a/server/tests/src/main/scala/sttp/tapir/server/tests/AllServerTests.scala +++ b/server/tests/src/main/scala/sttp/tapir/server/tests/AllServerTests.scala @@ -27,13 +27,14 @@ class AllServerTests[F[_], OPTIONS, ROUTE]( validation: Boolean = true, oneOfBody: Boolean = true, cors: Boolean = true, - options: Boolean = true + options: Boolean = true, + maxContentLength: Option[Int] = None )(implicit m: MonadError[F] ) { def tests(): List[Test] = (if (security) new ServerSecurityTests(createServerTest).tests() else Nil) ++ - (if (basic) new ServerBasicTests(createServerTest, serverInterpreter).tests() else Nil) ++ + (if (basic) new ServerBasicTests(createServerTest, serverInterpreter, maxContentLength = maxContentLength).tests() else Nil) ++ (if (contentNegotiation) new ServerContentNegotiationTests(createServerTest).tests() else Nil) ++ (if (file) new ServerFileTests(createServerTest).tests() else Nil) ++ (if (mapping) new ServerMappingTests(createServerTest).tests() else Nil) ++ diff --git a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerBasicTests.scala b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerBasicTests.scala index 5122c4fb8d..e82f8aa2dd 100644 --- a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerBasicTests.scala +++ b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerBasicTests.scala @@ -15,7 +15,7 @@ import sttp.tapir.codec.enumeratum.TapirCodecEnumeratum import sttp.tapir.generic.auto._ import sttp.tapir.json.circe._ import sttp.tapir.server.ServerEndpoint -import sttp.tapir.server.interceptor.decodefailure.{DecodeFailureHandler, DefaultDecodeFailureHandler} +import sttp.tapir.server.interceptor.decodefailure.DefaultDecodeFailureHandler import sttp.tapir.tests.Basic._ import sttp.tapir.tests.TestUtil._ import sttp.tapir.tests._ @@ -31,7 +31,8 @@ class ServerBasicTests[F[_], OPTIONS, ROUTE]( inputStreamSupport: Boolean = true, supportsUrlEncodedPathSegments: Boolean = true, supportsMultipleSetCookieHeaders: Boolean = true, - invulnerableToUnsanitizedHeaders: Boolean = true + invulnerableToUnsanitizedHeaders: Boolean = true, + maxContentLength: Option[Int] = None )(implicit m: MonadError[F] ) { @@ -46,6 +47,7 @@ class ServerBasicTests[F[_], OPTIONS, ROUTE]( customiseDecodeFailureHandlerTests() ++ serverSecurityLogicTests() ++ (if (inputStreamSupport) inputStreamTests() else Nil) ++ + (if (maxContentLength.nonEmpty) maxContentLengthTests() else Nil) ++ exceptionTests() def basicTests(): List[Test] = List( @@ -742,6 +744,12 @@ class ServerBasicTests[F[_], OPTIONS, ROUTE]( } ) + def maxContentLengthTests(): List[Test] = List( + testServer(in_string_out_string, "returns 413 on exceeded max content length")(_ => + pureResult(List.fill(maxContentLength.getOrElse(0) + 1)('x').mkString.asRight[Unit]) + ) { (backend, baseUri) => basicRequest.post(uri"$baseUri/api/echo").body("irrelevant").send(backend).map(_.code.code shouldBe 413) } + ) + def exceptionTests(): List[Test] = List( testServer(endpoint, "handle exceptions")(_ => throw new RuntimeException()) { (backend, baseUri) => basicRequest.get(uri"$baseUri").send(backend).map(_.code shouldBe StatusCode.InternalServerError) diff --git a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerStreamingTests.scala b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerStreamingTests.scala index 3369f6c009..13ee47524a 100644 --- a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerStreamingTests.scala +++ b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerStreamingTests.scala @@ -39,11 +39,7 @@ class ServerStreamingTests[F[_], S, OPTIONS, ROUTE](createServerTest: CreateServ .send(backend) .map { response => response.body shouldBe Right(penPineapple) - if (response.headers.contains(Header(HeaderNames.TransferEncoding, "chunked"))) { - response.contentLength shouldBe None - } else { - response.contentLength shouldBe Some(penPineapple.length) - } + response.contentLength shouldBe Some(penPineapple.length) } } },