diff --git a/build.sbt b/build.sbt
index 7bcecb9a2b..66571022f6 100644
--- a/build.sbt
+++ b/build.sbt
@@ -2044,7 +2044,8 @@ lazy val examples: ProjectMatrix = (projectMatrix in file("examples"))
       scalaTest.value
     ),
     libraryDependencies ++= loggerDependencies,
-    publishArtifact := false
+    publishArtifact := false,
+    Compile / run / fork := true
   )
   .jvmPlatform(scalaVersions = examplesScalaVersions)
   .dependsOn(
diff --git a/doc/server/netty.md b/doc/server/netty.md
index abb6eb25aa..f53228d530 100644
--- a/doc/server/netty.md
+++ b/doc/server/netty.md
@@ -63,6 +63,21 @@ NettyFutureServer(NettyFutureServerOptions.customiseInterceptors.serverLog(None)
 NettyFutureServer(NettyConfig.defaultNoStreaming.socketBacklog(256))
 ```
 
+## Graceful shutdown
+
+A Netty server can be gracefully closed using the function `NettyFutureServerBinding.stop()` (analogous functions are available in the Cats and ZIO bindings). This function ensures that the server waits at most 10 seconds for in-flight requests to complete, while rejecting all new requests with a 503 response during this period. Afterwards, it closes all server resources.
+You can customize this behavior in `NettyConfig`:
+
+```scala mdoc:compile-only
+import sttp.tapir.server.netty.NettyConfig
+import scala.concurrent.duration._
+
+// adjust the waiting time to your needs
+val config = NettyConfig.defaultNoStreaming.withGracefulShutdownTimeout(5.seconds)
+// or if you don't want the server to wait for in-flight requests
+val config2 = NettyConfig.defaultNoStreaming.noGracefulShutdown
+```
+
 ## Domain socket support
 
 There is possibility to use Domain socket instead of TCP for handling traffic.
diff --git a/examples/src/main/scala/sttp/tapir/examples/HelloWorldNettyCatsServer.scala b/examples/src/main/scala/sttp/tapir/examples/HelloWorldNettyCatsServer.scala
index a08185bc60..9ab3d19be5 100644
--- a/examples/src/main/scala/sttp/tapir/examples/HelloWorldNettyCatsServer.scala
+++ b/examples/src/main/scala/sttp/tapir/examples/HelloWorldNettyCatsServer.scala
@@ -1,13 +1,13 @@
 package sttp.tapir.examples
 
 import cats.effect.IO
+import cats.effect.IOApp
 import sttp.client3.{HttpURLConnectionBackend, Identity, SttpBackend, UriContext, asStringAlways, basicRequest}
 import sttp.model.StatusCode
 import sttp.tapir.{PublicEndpoint, endpoint, query, stringBody}
-import cats.effect.unsafe.implicits.global
-import sttp.tapir.server.netty.cats.{NettyCatsServer, NettyCatsServerBinding}
+import sttp.tapir.server.netty.cats.NettyCatsServer
 
-object HelloWorldNettyCatsServer extends App {
+object HelloWorldNettyCatsServer extends IOApp.Simple {
   // One endpoint on GET /hello with query parameter `name`
   val helloWorldEndpoint: PublicEndpoint[String, Unit, String, Any] =
     endpoint.get.in("hello").in(query[String]("name")).out(stringBody)
@@ -20,37 +20,37 @@ object HelloWorldNettyCatsServer extends App {
   private val declaredHost = "localhost"
 
   // Creating handler for netty bootstrap
-  NettyCatsServer
+  override def run = NettyCatsServer
     .io()
     .use { server =>
-
-      val effect: IO[NettyCatsServerBinding[IO]] = server
-        .port(declaredPort)
-        .host(declaredHost)
-        .addEndpoint(helloWorldServerEndpoint)
-        .start()
-
-      effect.map { binding =>
-
-        val port = binding.port
-        val host = binding.hostName
-        println(s"Server started at port = ${binding.port}")
-
-        val backend: SttpBackend[Identity, Any] = HttpURLConnectionBackend()
-        val badUrl = uri"http://$host:$port/bad_url"
-        assert(basicRequest.response(asStringAlways).get(badUrl).send(backend).code == StatusCode(404))
-
-        val noQueryParameter = 
uri"http://$host:$port/hello" - assert(basicRequest.response(asStringAlways).get(noQueryParameter).send(backend).code == StatusCode(400)) - - val allGood = uri"http://$host:$port/hello?name=Netty" - val body = basicRequest.response(asStringAlways).get(allGood).send(backend).body - - println("Got result: " + body) - assert(body == "Hello, Netty!") - assert(port == declaredPort, "Ports don't match") - assert(host == declaredHost, "Hosts don't match") - } + for { + binding <- server + .port(declaredPort) + .host(declaredHost) + .addEndpoint(helloWorldServerEndpoint) + .start() + result <- IO + .blocking { + val port = binding.port + val host = binding.hostName + println(s"Server started at port = ${binding.port}") + + val backend: SttpBackend[Identity, Any] = HttpURLConnectionBackend() + val badUrl = uri"http://$host:$port/bad_url" + assert(basicRequest.response(asStringAlways).get(badUrl).send(backend).code == StatusCode(404)) + + val noQueryParameter = uri"http://$host:$port/hello" + assert(basicRequest.response(asStringAlways).get(noQueryParameter).send(backend).code == StatusCode(400)) + + val allGood = uri"http://$host:$port/hello?name=Netty" + val body = basicRequest.response(asStringAlways).get(allGood).send(backend).body + + println("Got result: " + body) + assert(body == "Hello, Netty!") + assert(port == declaredPort, "Ports don't match") + assert(host == declaredHost, "Hosts don't match") + } + .guarantee(binding.stop()) + } yield result } - .unsafeRunSync() } diff --git a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServer.scala b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServer.scala index 7e7b3c8fea..340ebd76dc 100644 --- a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServer.scala +++ b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServer.scala @@ -1,10 +1,13 @@ package sttp.tapir.server.netty.cats +import cats.effect.kernel.Sync import cats.effect.std.Dispatcher -import cats.effect.{Async, IO, Resource} +import cats.effect.{Async, IO, Resource, Temporal} import cats.syntax.all._ import io.netty.channel._ +import io.netty.channel.group.{ChannelGroup, DefaultChannelGroup} import io.netty.channel.unix.DomainSocketAddress +import io.netty.util.concurrent.DefaultEventExecutor import sttp.capabilities.fs2.Fs2Streams import sttp.monad.MonadError import sttp.tapir.integ.cats.effect.CatsMonadError @@ -17,7 +20,9 @@ import sttp.tapir.server.netty.{NettyConfig, NettyResponse, Route} import java.net.{InetSocketAddress, SocketAddress} import java.nio.file.{Path, Paths} import java.util.UUID +import java.util.concurrent.atomic.AtomicBoolean import scala.concurrent.Future +import scala.concurrent.duration._ case class NettyCatsServer[F[_]: Async](routes: Vector[Route[F]], options: NettyCatsServerOptions[F], config: NettyConfig) { def addEndpoint(se: ServerEndpoint[Fs2Streams[F], F]): NettyCatsServer[F] = addEndpoints(List(se)) @@ -62,26 +67,57 @@ case class NettyCatsServer[F[_]: Async](routes: Vector[Route[F]], options: Netty val eventLoopGroup = config.eventLoopConfig.initEventLoopGroup() implicit val monadError: MonadError[F] = new CatsMonadError[F]() val route: Route[F] = Route.combine(routes) + val channelGroup = new DefaultChannelGroup(new DefaultEventExecutor()) // thread safe + val isShuttingDown: AtomicBoolean = new AtomicBoolean(false) val channelFuture = NettyBootstrap( config, - new NettyServerHandler(route, unsafeRunAsync, config.maxContentLength), + 
new NettyServerHandler(route, unsafeRunAsync, config.maxContentLength, channelGroup, isShuttingDown), eventLoopGroup, socketOverride ) - nettyChannelFutureToScala(channelFuture).map(ch => (ch.localAddress().asInstanceOf[SA], () => stop(ch, eventLoopGroup))) + nettyChannelFutureToScala(channelFuture).map(ch => + ( + ch.localAddress().asInstanceOf[SA], + () => stop(ch, eventLoopGroup, channelGroup, isShuttingDown, config.gracefulShutdownTimeout) + ) + ) } - private def stop(ch: Channel, eventLoopGroup: EventLoopGroup): F[Unit] = { - Async[F].defer { - nettyFutureToScala(ch.close()).flatMap { _ => - if (config.shutdownEventLoopGroupOnClose) { - nettyFutureToScala(eventLoopGroup.shutdownGracefully()).map(_ => ()) - } else Async[F].unit - } + private def waitForClosedChannels( + channelGroup: ChannelGroup, + startNanos: Long, + gracefulShutdownTimeoutNanos: Option[Long] + ): F[Unit] = + if (!channelGroup.isEmpty && gracefulShutdownTimeoutNanos.exists(_ >= System.nanoTime() - startNanos)) { + Temporal[F].sleep(100.millis) >> + waitForClosedChannels(channelGroup, startNanos, gracefulShutdownTimeoutNanos) + } else { + Sync[F].delay(nettyFutureToScala(channelGroup.close())).void } + + private def stop( + ch: Channel, + eventLoopGroup: EventLoopGroup, + channelGroup: ChannelGroup, + isShuttingDown: AtomicBoolean, + gracefulShutdownTimeout: Option[FiniteDuration] + ): F[Unit] = { + Sync[F].delay(isShuttingDown.set(true)) >> + waitForClosedChannels( + channelGroup, + startNanos = System.nanoTime(), + gracefulShutdownTimeoutNanos = gracefulShutdownTimeout.map(_.toNanos) + ) >> + Async[F].defer { + nettyFutureToScala(ch.close()).flatMap { _ => + if (config.shutdownEventLoopGroupOnClose) { + nettyFutureToScala(eventLoopGroup.shutdownGracefully()).map(_ => ()) + } else Async[F].unit + } + } } } diff --git a/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsServerTest.scala b/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsServerTest.scala index 7b3f2a1303..cde65b3d38 100644 --- a/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsServerTest.scala +++ b/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsServerTest.scala @@ -11,6 +11,7 @@ import sttp.tapir.server.tests._ import sttp.tapir.tests.{Test, TestSuite} import scala.concurrent.Future +import scala.concurrent.duration.FiniteDuration class NettyCatsServerTest extends TestSuite with EitherValues { @@ -23,6 +24,9 @@ class NettyCatsServerTest extends TestSuite with EitherValues { val interpreter = new NettyCatsTestServerInterpreter(eventLoopGroup, dispatcher) val createServerTest = new DefaultCreateServerTest(backend, interpreter) + val ioSleeper: Sleeper[IO] = new Sleeper[IO] { + override def sleep(duration: FiniteDuration): IO[Unit] = IO.sleep(duration) + } val tests = new AllServerTests( createServerTest, @@ -34,7 +38,8 @@ class NettyCatsServerTest extends TestSuite with EitherValues { .tests() ++ new ServerStreamingTests(createServerTest, Fs2Streams[IO]).tests() ++ new ServerCancellationTests(createServerTest)(m, IO.asyncForIO).tests() ++ - new NettyFs2StreamingCancellationTest(createServerTest).tests() + new NettyFs2StreamingCancellationTest(createServerTest).tests() ++ + new ServerGracefulShutdownTests(createServerTest, ioSleeper).tests() IO.pure((tests, eventLoopGroup)) } { case (_, eventLoopGroup) => diff --git a/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsTestServerInterpreter.scala 
b/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsTestServerInterpreter.scala index b0847aedb4..68e74d197b 100644 --- a/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsTestServerInterpreter.scala +++ b/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsTestServerInterpreter.scala @@ -9,6 +9,7 @@ import sttp.tapir.server.netty.{NettyConfig, Route} import sttp.tapir.server.tests.TestServerInterpreter import sttp.tapir.tests.Port import sttp.capabilities.fs2.Fs2Streams +import scala.concurrent.duration.FiniteDuration class NettyCatsTestServerInterpreter(eventLoopGroup: NioEventLoopGroup, dispatcher: Dispatcher[IO]) extends TestServerInterpreter[IO, Fs2Streams[IO], NettyCatsServerOptions[IO], Route[IO]] { @@ -19,18 +20,28 @@ class NettyCatsTestServerInterpreter(eventLoopGroup: NioEventLoopGroup, dispatch NettyCatsServerInterpreter(serverOptions).toRoute(es) } - override def server(routes: NonEmptyList[Route[IO]]): Resource[IO, Port] = { + override def serverWithStop( + routes: NonEmptyList[Route[IO]], + gracefulShutdownTimeout: Option[FiniteDuration] = None + ): Resource[IO, (Port, IO[Unit])] = { val config = NettyConfig.defaultWithStreaming .eventLoopGroup(eventLoopGroup) .randomPort .withDontShutdownEventLoopGroupOnClose .maxContentLength(NettyCatsTestServerInterpreter.maxContentLength) + .noGracefulShutdown + + val customizedConfig = gracefulShutdownTimeout.map(config.withGracefulShutdownTimeout).getOrElse(config) val options = NettyCatsServerOptions.default[IO](dispatcher) - val bind: IO[NettyCatsServerBinding[IO]] = NettyCatsServer(options, config).addRoutes(routes.toList).start() + val bind: IO[NettyCatsServerBinding[IO]] = NettyCatsServer(options, customizedConfig).addRoutes(routes.toList).start() Resource - .make(bind)(_.stop()) - .map(_.port) + .make(bind.map(b => (b, b.stop()))) { case (_, stop) => stop } + .map { case (b, stop) => (b.port, stop) } + } + + override def server(routes: NonEmptyList[Route[IO]]): Resource[IO, Port] = { + serverWithStop(routes).map(_._1) } } diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyConfig.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyConfig.scala index ee37e6c7c8..fbadd899fe 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyConfig.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyConfig.scala @@ -47,6 +47,10 @@ import scala.concurrent.duration._ * @param lingerTimeout * Sets the delay for which the Netty waits, while data is being transmitted, before closing a socket after receiving a call to close the * socket + * + * @param gracefulShutdownTimeout + * If set, attempts to wait for a given time for all in-flight requests to complete, before proceeding with shutting down the server. If + * `None`, closes the channels and terminates the server without waiting. 
*/ case class NettyConfig( host: String, @@ -64,7 +68,8 @@ case class NettyConfig( sslContext: Option[SslContext], eventLoopConfig: EventLoopConfig, socketConfig: NettySocketConfig, - initPipeline: NettyConfig => (ChannelPipeline, ChannelHandler) => Unit + initPipeline: NettyConfig => (ChannelPipeline, ChannelHandler) => Unit, + gracefulShutdownTimeout: Option[FiniteDuration] ) { def host(h: String): NettyConfig = copy(host = h) @@ -102,6 +107,9 @@ case class NettyConfig( def eventLoopGroup(elg: EventLoopGroup): NettyConfig = copy(eventLoopConfig = EventLoopConfig.useExisting(elg)) def initPipeline(f: NettyConfig => (ChannelPipeline, ChannelHandler) => Unit): NettyConfig = copy(initPipeline = f) + + def withGracefulShutdownTimeout(t: FiniteDuration) = copy(gracefulShutdownTimeout = Some(t)) + def noGracefulShutdown = copy(gracefulShutdownTimeout = None) } object NettyConfig { @@ -115,6 +123,7 @@ object NettyConfig { connectionTimeout = Some(10.seconds), socketTimeout = Some(60.seconds), lingerTimeout = Some(60.seconds), + gracefulShutdownTimeout = Some(10.seconds), maxContentLength = None, maxConnections = None, addLoggingHandler = false, diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyFutureServer.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyFutureServer.scala index d92c704b02..eaa7a86fe8 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyFutureServer.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyFutureServer.scala @@ -1,17 +1,21 @@ package sttp.tapir.server.netty import io.netty.channel._ +import io.netty.channel.group.{ChannelGroup, DefaultChannelGroup} import io.netty.channel.unix.DomainSocketAddress +import io.netty.util.concurrent.DefaultEventExecutor import sttp.monad.{FutureMonad, MonadError} import sttp.tapir.server.ServerEndpoint +import sttp.tapir.server.model.ServerResponse import sttp.tapir.server.netty.internal.FutureUtil._ import sttp.tapir.server.netty.internal.{NettyBootstrap, NettyServerHandler} import java.net.{InetSocketAddress, SocketAddress} import java.nio.file.{Path, Paths} import java.util.UUID -import scala.concurrent.{ExecutionContext, Future} -import sttp.tapir.server.model.ServerResponse +import java.util.concurrent.atomic.AtomicBoolean +import scala.concurrent.duration.FiniteDuration +import scala.concurrent.{ExecutionContext, Future, blocking} case class NettyFutureServer(routes: Vector[FutureRoute], options: NettyFutureServerOptions, config: NettyConfig)(implicit ec: ExecutionContext @@ -60,23 +64,57 @@ case class NettyFutureServer(routes: Vector[FutureRoute], options: NettyFutureSe val eventLoopGroup = config.eventLoopConfig.initEventLoopGroup() implicit val monadError: MonadError[Future] = new FutureMonad() val route = Route.combine(routes) + val channelGroup = new DefaultChannelGroup(new DefaultEventExecutor()) // thread safe + val isShuttingDown: AtomicBoolean = new AtomicBoolean(false) val channelFuture = NettyBootstrap( config, - new NettyServerHandler(route, unsafeRunAsync, config.maxContentLength), + new NettyServerHandler(route, unsafeRunAsync, config.maxContentLength, channelGroup, isShuttingDown), eventLoopGroup, socketOverride ) - nettyChannelFutureToScala(channelFuture).map(ch => (ch.localAddress().asInstanceOf[SA], () => stop(ch, eventLoopGroup))) + nettyChannelFutureToScala(channelFuture).map(ch => + (ch.localAddress().asInstanceOf[SA], () => stop(ch, eventLoopGroup, channelGroup, isShuttingDown, config.gracefulShutdownTimeout)) + ) 
} - private def stop(ch: Channel, eventLoopGroup: EventLoopGroup): Future[Unit] = { - nettyFutureToScala(ch.close()).flatMap { _ => - if (config.shutdownEventLoopGroupOnClose) { - nettyFutureToScala(eventLoopGroup.shutdownGracefully()).map(_ => ()) - } else Future.successful(()) + private def waitForClosedChannels( + channelGroup: ChannelGroup, + startNanos: Long, + gracefulShutdownTimeoutNanos: Option[Long] + ): Future[Unit] = + if (!channelGroup.isEmpty && gracefulShutdownTimeoutNanos.exists(_ >= System.nanoTime() - startNanos)) { + Future { + blocking { + Thread.sleep(100) + } + }.flatMap(_ => { + waitForClosedChannels(channelGroup, startNanos, gracefulShutdownTimeoutNanos) + }) + } else { + nettyFutureToScala(channelGroup.close()).map(_ => ()) + } + + private def stop( + ch: Channel, + eventLoopGroup: EventLoopGroup, + channelGroup: ChannelGroup, + isShuttingDown: AtomicBoolean, + gracefulShutdownTimeout: Option[FiniteDuration] + ): Future[Unit] = { + isShuttingDown.set(true) + waitForClosedChannels( + channelGroup, + startNanos = System.nanoTime(), + gracefulShutdownTimeoutNanos = gracefulShutdownTimeout.map(_.toNanos) + ).flatMap { _ => + nettyFutureToScala(ch.close()).flatMap { _ => + if (config.shutdownEventLoopGroupOnClose) { + nettyFutureToScala(eventLoopGroup.shutdownGracefully()).map(_ => ()) + } else Future.successful(()) + } } } } diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala index d5bf90f4c7..bef744142d 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala @@ -1,12 +1,13 @@ package sttp.tapir.server.netty.internal -import org.playframework.netty.http.{DefaultStreamedHttpResponse, StreamedHttpRequest} import com.typesafe.scalalogging.Logger import io.netty.buffer.{ByteBuf, Unpooled} import io.netty.channel._ +import io.netty.channel.group.ChannelGroup import io.netty.handler.codec.http.HttpHeaderNames.{CONNECTION, CONTENT_LENGTH} import io.netty.handler.codec.http._ import io.netty.handler.stream.{ChunkedFile, ChunkedStream} +import org.playframework.netty.http.{DefaultStreamedHttpResponse, StreamedHttpRequest} import org.reactivestreams.Publisher import sttp.monad.MonadError import sttp.monad.syntax._ @@ -19,13 +20,12 @@ import sttp.tapir.server.netty.NettyResponseContent.{ } import sttp.tapir.server.netty.{NettyResponse, NettyServerRequest, Route} +import java.util.concurrent.atomic.AtomicBoolean import scala.collection.JavaConverters._ import scala.collection.mutable.{Queue => MutableQueue} -import scala.concurrent.Future -import scala.util.Failure -import scala.util.Success +import scala.concurrent.{ExecutionContext, Future} import scala.util.control.NonFatal -import scala.concurrent.ExecutionContext +import scala.util.{Failure, Success} /** @param unsafeRunAsync * Function which dispatches given effect to run asynchronously, returning its result as a Future, and function of type `() => @@ -35,7 +35,9 @@ import scala.concurrent.ExecutionContext class NettyServerHandler[F[_]]( route: Route[F], unsafeRunAsync: (() => F[ServerResponse[NettyResponse]]) => (Future[ServerResponse[NettyResponse]], () => Future[Unit]), - maxContentLength: Option[Int] + maxContentLength: Option[Int], + channelGroup: ChannelGroup, + isShuttingDown: AtomicBoolean )(implicit me: MonadError[F] ) 
extends SimpleChannelInboundHandler[HttpRequest] { @@ -79,7 +81,10 @@ class NettyServerHandler[F[_]]( if (ctx.channel.isActive) { initHandler(ctx) } - override def channelActive(ctx: ChannelHandlerContext): Unit = initHandler(ctx) + override def channelActive(ctx: ChannelHandlerContext): Unit = { + channelGroup.add(ctx.channel) + initHandler(ctx) + } private[this] def initHandler(ctx: ChannelHandlerContext): Unit = { if (eventLoopContext == null) { @@ -99,13 +104,18 @@ class NettyServerHandler[F[_]]( def writeError500(req: HttpRequest, reason: Throwable): Unit = { logger.error("Error while processing the request", reason) - // send 500 val res = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.INTERNAL_SERVER_ERROR) res.headers().set(HttpHeaderNames.CONTENT_LENGTH, 0) res.handleCloseAndKeepAliveHeaders(req) ctx.writeAndFlush(res).closeIfNeeded(req) + } + def writeError503(req: HttpRequest): Unit = { + val res = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.SERVICE_UNAVAILABLE) + res.headers().set(HttpHeaderNames.CONTENT_LENGTH, 0) + res.handleCloseAndKeepAliveHeaders(req) + ctx.writeAndFlush(res).closeIfNeeded(req) } def runRoute(req: HttpRequest, releaseReq: () => Any = () => ()): Unit = { @@ -122,11 +132,11 @@ class NettyServerHandler[F[_]]( case Success(serverResponse) => pendingResponses.dequeue() try { - handleResponse(ctx, req, serverResponse) + handleResponse(ctx, req, serverResponse) Success(()) } catch { case NonFatal(ex) => - writeError500(req, ex) + writeError500(req, ex) Failure(ex) } finally { val _ = releaseReq() @@ -135,8 +145,7 @@ class NettyServerHandler[F[_]]( try { writeError500(req, ex) Failure(ex) - } - finally { + } finally { val _ = releaseReq() } case Failure(fatalException) => Failure(fatalException) @@ -144,7 +153,10 @@ class NettyServerHandler[F[_]]( }(eventLoopContext) } - if (HttpUtil.is100ContinueExpected(request)) { + if (isShuttingDown.get()) { + logger.info("Rejecting request, server is shutting down") + writeError503(request) + } else if (HttpUtil.is100ContinueExpected(request)) { ctx.writeAndFlush(new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.CONTINUE)) () } else { @@ -291,7 +303,7 @@ class NettyServerHandler[F[_]]( } def handleCloseAndKeepAliveHeaders(request: HttpRequest): Unit = { - if (!HttpUtil.isKeepAlive(request)) + if (!HttpUtil.isKeepAlive(request) || isShuttingDown.get()) m.headers.set(HttpHeaderNames.CONNECTION, HttpHeaderValues.CLOSE) else if (request.protocolVersion.equals(HttpVersion.HTTP_1_0)) m.headers.set(HttpHeaderNames.CONNECTION, HttpHeaderValues.KEEP_ALIVE) @@ -300,7 +312,7 @@ class NettyServerHandler[F[_]]( private implicit class RichChannelFuture(val cf: ChannelFuture) { def closeIfNeeded(request: HttpRequest): Unit = { - if (!HttpUtil.isKeepAlive(request)) { + if (!HttpUtil.isKeepAlive(request) || isShuttingDown.get()) { cf.addListener(ChannelFutureListener.CLOSE) } } diff --git a/server/netty-server/src/test/scala/sttp/tapir/server/netty/NettyFutureServerTest.scala b/server/netty-server/src/test/scala/sttp/tapir/server/netty/NettyFutureServerTest.scala index 9b24ca83ed..b7d86c5e4a 100644 --- a/server/netty-server/src/test/scala/sttp/tapir/server/netty/NettyFutureServerTest.scala +++ b/server/netty-server/src/test/scala/sttp/tapir/server/netty/NettyFutureServerTest.scala @@ -21,7 +21,8 @@ class NettyFutureServerTest extends TestSuite with EitherValues { val interpreter = new NettyFutureTestServerInterpreter(eventLoopGroup) val createServerTest = new 
DefaultCreateServerTest(backend, interpreter) - val tests = new AllServerTests(createServerTest, interpreter, backend, multipart = false).tests() + val tests = new AllServerTests(createServerTest, interpreter, backend, multipart = false).tests() ++ + new ServerGracefulShutdownTests(createServerTest, Sleeper.futureSleeper).tests() (tests, eventLoopGroup) }) { case (_, eventLoopGroup) => diff --git a/server/netty-server/src/test/scala/sttp/tapir/server/netty/NettyFutureTestServerInterpreter.scala b/server/netty-server/src/test/scala/sttp/tapir/server/netty/NettyFutureTestServerInterpreter.scala index 37f2f2481a..174bedd65a 100644 --- a/server/netty-server/src/test/scala/sttp/tapir/server/netty/NettyFutureTestServerInterpreter.scala +++ b/server/netty-server/src/test/scala/sttp/tapir/server/netty/NettyFutureTestServerInterpreter.scala @@ -8,22 +8,37 @@ import sttp.tapir.server.tests.TestServerInterpreter import sttp.tapir.tests.Port import scala.concurrent.{ExecutionContext, Future} +import scala.concurrent.duration.FiniteDuration -class NettyFutureTestServerInterpreter(eventLoopGroup: NioEventLoopGroup)(implicit ec: ExecutionContext) - extends TestServerInterpreter[Future, Any, NettyFutureServerOptions, FutureRoute] { +class NettyFutureTestServerInterpreter(eventLoopGroup: NioEventLoopGroup)(implicit + ec: ExecutionContext +) extends TestServerInterpreter[Future, Any, NettyFutureServerOptions, FutureRoute] { override def route(es: List[ServerEndpoint[Any, Future]], interceptors: Interceptors): FutureRoute = { val serverOptions = interceptors(NettyFutureServerOptions.customiseInterceptors).options NettyFutureServerInterpreter(serverOptions).toRoute(es) } - override def server(routes: NonEmptyList[FutureRoute]): Resource[IO, Port] = { - val config = NettyConfig.defaultNoStreaming.eventLoopGroup(eventLoopGroup).randomPort.withDontShutdownEventLoopGroupOnClose + override def serverWithStop( + routes: NonEmptyList[FutureRoute], + gracefulShutdownTimeout: Option[FiniteDuration] = None + ): Resource[IO, (Port, IO[Unit])] = { + val config = + NettyConfig.defaultNoStreaming + .eventLoopGroup(eventLoopGroup) + .randomPort + .withDontShutdownEventLoopGroupOnClose + .noGracefulShutdown + val customizedConfig = gracefulShutdownTimeout.map(config.withGracefulShutdownTimeout).getOrElse(config) val options = NettyFutureServerOptions.default - val bind = IO.fromFuture(IO.delay(NettyFutureServer(options, config).addRoutes(routes.toList).start())) + val bind = IO.fromFuture(IO.delay(NettyFutureServer(options, customizedConfig).addRoutes(routes.toList).start())) Resource - .make(bind)(binding => IO.fromFuture(IO.delay(binding.stop()))) - .map(b => b.port) + .make(bind.map(b => (b, IO.fromFuture(IO.delay(b.stop()))))) { case (_, stop) => stop } + .map { case (b, stop) => (b.port, stop) } + } + + override def server(routes: NonEmptyList[FutureRoute]): Resource[IO, Port] = { + serverWithStop(routes).map(_._1) } } diff --git a/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/NettyZioServer.scala b/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/NettyZioServer.scala index c642949bc6..968e0d1013 100644 --- a/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/NettyZioServer.scala +++ b/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/NettyZioServer.scala @@ -1,7 +1,9 @@ package sttp.tapir.server.netty.zio import io.netty.channel._ +import io.netty.channel.group.{ChannelGroup, DefaultChannelGroup} import io.netty.channel.unix.DomainSocketAddress 
+import io.netty.util.concurrent.DefaultEventExecutor import sttp.capabilities.zio.ZioStreams import sttp.tapir.server.ServerEndpoint import sttp.tapir.server.model.ServerResponse @@ -9,13 +11,16 @@ import sttp.tapir.server.netty.internal.{NettyBootstrap, NettyServerHandler} import sttp.tapir.server.netty.zio.internal.ZioUtil.{nettyChannelFutureToScala, nettyFutureToScala} import sttp.tapir.server.netty.{NettyConfig, NettyResponse, Route} import sttp.tapir.ztapir.{RIOMonadError, ZServerEndpoint} -import zio.{RIO, Unsafe, ZIO} +import zio.{RIO, Task, Unsafe, ZIO, durationInt} import java.net.{InetSocketAddress, SocketAddress} import java.nio.file.{Path, Paths} import java.util.UUID +import java.util.concurrent.atomic.AtomicBoolean + import scala.concurrent.ExecutionContext.Implicits import scala.concurrent.Future +import scala.concurrent.duration.FiniteDuration case class NettyZioServer[R](routes: Vector[RIO[R, Route[RIO[R, *]]]], options: NettyZioServerOptions[R], config: NettyConfig) { def addEndpoint(se: ZServerEndpoint[R, ZioStreams]): NettyZioServer[R] = addEndpoints(List(se)) @@ -73,6 +78,8 @@ case class NettyZioServer[R](routes: Vector[RIO[R, Route[RIO[R, *]]]], options: runtime <- ZIO.runtime[R] routes <- ZIO.foreach(routes)(identity) eventLoopGroup = config.eventLoopConfig.initEventLoopGroup() + channelGroup = new DefaultChannelGroup(new DefaultEventExecutor()) // thread safe + isShuttingDown = new AtomicBoolean(false) channelFuture = { implicit val monadError: RIOMonadError[R] = new RIOMonadError[R] val route: Route[RIO[R, *]] = Route.combine(routes) @@ -82,7 +89,9 @@ case class NettyZioServer[R](routes: Vector[RIO[R, Route[RIO[R, *]]]], options: new NettyServerHandler[RIO[R, *]]( route, unsafeRunAsync(runtime), - config.maxContentLength + config.maxContentLength, + channelGroup, + isShuttingDown ), eventLoopGroup, socketOverride @@ -91,19 +100,43 @@ case class NettyZioServer[R](routes: Vector[RIO[R, Route[RIO[R, *]]]], options: binding <- nettyChannelFutureToScala(channelFuture).map(ch => ( ch.localAddress().asInstanceOf[SA], - () => stop(ch, eventLoopGroup) + () => stop(ch, eventLoopGroup, channelGroup, isShuttingDown, config.gracefulShutdownTimeout) ) ) } yield binding - private def stop(ch: Channel, eventLoopGroup: EventLoopGroup): RIO[R, Unit] = { - ZIO.suspend { - nettyFutureToScala(ch.close()).flatMap { _ => - if (config.shutdownEventLoopGroupOnClose) { - nettyFutureToScala(eventLoopGroup.shutdownGracefully()).map(_ => ()) - } else ZIO.succeed(()) - } + private def waitForClosedChannels( + channelGroup: ChannelGroup, + startNanos: Long, + gracefulShutdownTimeoutNanos: Option[Long] + ): Task[Unit] = + if (!channelGroup.isEmpty && gracefulShutdownTimeoutNanos.exists(_ >= System.nanoTime() - startNanos)) { + ZIO.sleep(100.millis) *> + waitForClosedChannels(channelGroup, startNanos, gracefulShutdownTimeoutNanos) + } else { + ZIO.attempt(channelGroup.close()).unit } + + private def stop( + ch: Channel, + eventLoopGroup: EventLoopGroup, + channelGroup: ChannelGroup, + isShuttingDown: AtomicBoolean, + gracefulShutdownTimeout: Option[FiniteDuration] + ): RIO[R, Unit] = { + ZIO.attempt(isShuttingDown.set(true)) *> + waitForClosedChannels( + channelGroup, + startNanos = System.nanoTime(), + gracefulShutdownTimeoutNanos = gracefulShutdownTimeout.map(_.toNanos) + ) *> + ZIO.suspend { + nettyFutureToScala(ch.close()).flatMap { _ => + if (config.shutdownEventLoopGroupOnClose) { + nettyFutureToScala(eventLoopGroup.shutdownGracefully()).map(_ => ()) + } else ZIO.succeed(()) + } + } } } 
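All three backends (Future, Cats Effect, ZIO) implement the same shutdown sequence: set the `isShuttingDown` flag so the handler starts answering new requests with 503, poll the channel group until it is empty or the graceful shutdown timeout elapses, force-close any remaining channels, and only then close the server channel and, optionally, the event loop group. A minimal standalone sketch of that sequence, using a hypothetical `gracefulClose` helper and the same 100 ms polling interval as the implementations above:

```scala
import io.netty.channel.group.ChannelGroup
import java.util.concurrent.atomic.AtomicBoolean
import scala.concurrent.{ExecutionContext, Future, blocking}

// Hypothetical helper, for illustration only. Marks the server as shutting
// down, then polls the channel group until all tracked connections are closed
// or the timeout elapses, and finally force-closes whatever is still open.
def gracefulClose(channels: ChannelGroup, isShuttingDown: AtomicBoolean, timeoutNanos: Long)(implicit
    ec: ExecutionContext
): Future[Unit] = {
  isShuttingDown.set(true) // from now on the handler replies 503 to new requests
  val startNanos = System.nanoTime()
  def loop(): Future[Unit] =
    if (!channels.isEmpty && System.nanoTime() - startNanos <= timeoutNanos)
      Future(blocking(Thread.sleep(100))).flatMap(_ => loop()) // poll every 100 ms
    else Future(channels.close()).map(_ => ()) // force-close remaining channels
  loop()
}
```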
diff --git a/server/netty-server/zio/src/test/scala/sttp/tapir/server/netty/zio/NettyZioServerTest.scala b/server/netty-server/zio/src/test/scala/sttp/tapir/server/netty/zio/NettyZioServerTest.scala index 227d9223db..1486c0bb6c 100644 --- a/server/netty-server/zio/src/test/scala/sttp/tapir/server/netty/zio/NettyZioServerTest.scala +++ b/server/netty-server/zio/src/test/scala/sttp/tapir/server/netty/zio/NettyZioServerTest.scala @@ -9,10 +9,11 @@ import sttp.tapir.server.netty.internal.FutureUtil import sttp.tapir.server.tests._ import sttp.tapir.tests.{Test, TestSuite} import sttp.tapir.ztapir.RIOMonadError -import zio.Task import zio.interop.catz._ +import zio.{Task, ZIO} import scala.concurrent.Future +import scala.concurrent.duration.FiniteDuration class NettyZioServerTest extends TestSuite with EitherValues { override def tests: Resource[IO, List[Test]] = @@ -24,11 +25,15 @@ class NettyZioServerTest extends TestSuite with EitherValues { val interpreter = new NettyZioTestServerInterpreter(eventLoopGroup) val createServerTest = new DefaultCreateServerTest(backend, interpreter) + val zioSleeper: Sleeper[Task] = new Sleeper[Task] { + override def sleep(duration: FiniteDuration): Task[Unit] = ZIO.sleep(zio.Duration.fromScala(duration)) + } val tests = new AllServerTests(createServerTest, interpreter, backend, staticContent = false, multipart = false).tests() ++ new ServerStreamingTests(createServerTest, ZioStreams).tests() ++ - new ServerCancellationTests(createServerTest)(monadError, asyncInstance).tests() + new ServerCancellationTests(createServerTest)(monadError, asyncInstance).tests() ++ + new ServerGracefulShutdownTests(createServerTest, zioSleeper).tests() IO.pure((tests, eventLoopGroup)) } { case (_, eventLoopGroup) => diff --git a/server/netty-server/zio/src/test/scala/sttp/tapir/server/netty/zio/NettyZioTestServerInterpreter.scala b/server/netty-server/zio/src/test/scala/sttp/tapir/server/netty/zio/NettyZioTestServerInterpreter.scala index 5bc3bcf604..9f00d05386 100644 --- a/server/netty-server/zio/src/test/scala/sttp/tapir/server/netty/zio/NettyZioTestServerInterpreter.scala +++ b/server/netty-server/zio/src/test/scala/sttp/tapir/server/netty/zio/NettyZioTestServerInterpreter.scala @@ -8,7 +8,9 @@ import sttp.tapir.server.ServerEndpoint import sttp.tapir.server.netty.{NettyConfig, Route} import sttp.tapir.server.tests.TestServerInterpreter import sttp.tapir.tests.Port -import zio.{CancelableFuture, Runtime, Task, Unsafe} +import zio.{Runtime, Task, Unsafe} + +import scala.concurrent.duration.FiniteDuration class NettyZioTestServerInterpreter[R](eventLoopGroup: NioEventLoopGroup) extends TestServerInterpreter[Task, ZioStreams, NettyZioServerOptions[Any], Task[Route[Task]]] { @@ -19,19 +21,38 @@ class NettyZioTestServerInterpreter[R](eventLoopGroup: NioEventLoopGroup) NettyZioServerInterpreter(serverOptions).toRoute(es) } - override def server(routes: NonEmptyList[Task[Route[Task]]]): Resource[IO, Port] = { - val config = NettyConfig.defaultWithStreaming.eventLoopGroup(eventLoopGroup).randomPort.withDontShutdownEventLoopGroupOnClose + override def serverWithStop( + routes: NonEmptyList[Task[Route[Task]]], + gracefulShutdownTimeout: Option[FiniteDuration] = None + ): Resource[IO, (Port, IO[Unit])] = { + val config = NettyConfig.defaultWithStreaming + .eventLoopGroup(eventLoopGroup) + .randomPort + .withDontShutdownEventLoopGroupOnClose + .noGracefulShutdown + + val customizedConfig = gracefulShutdownTimeout.map(config.withGracefulShutdownTimeout).getOrElse(config) val options = 
NettyZioServerOptions.default[R]
    val runtime: Runtime[R] = Runtime.default.asInstanceOf[Runtime[R]]
-    val server: CancelableFuture[NettyZioServerBinding[R]] =
-      Unsafe.unsafe(implicit u => runtime.unsafe.runToFuture(NettyZioServer(options, config).addRoutes(routes.toList).start()))
+    val bind: IO[NettyZioServerBinding[R]] =
+      IO.fromFuture(
+        IO.delay(
+          Unsafe.unsafe(implicit u =>
+            runtime.unsafe.runToFuture(NettyZioServer(options, customizedConfig).addRoutes(routes.toList).start())
+          )
+        )
+      )
     Resource
-      .make(IO.fromFuture(IO(server)))(binding =>
-        IO.fromFuture(IO(Unsafe.unsafe(implicit u => runtime.unsafe.runToFuture(binding.stop()))))
-      )
-      .map(b => b.port)
+      .make(bind.map(b => (b, IO.fromFuture[Unit](IO(Unsafe.unsafe(implicit u => runtime.unsafe.runToFuture(b.stop()))))))) {
+        case (_, stop) => stop
+      }
+      .map { case (b, stop) => (b.port, stop) }
+  }
+
+  override def server(routes: NonEmptyList[Task[Route[Task]]]): Resource[IO, Port] = {
+    serverWithStop(routes).map(_._1)
   }
 }
diff --git a/server/tests/src/main/scala/sttp/tapir/server/tests/CreateServerTest.scala b/server/tests/src/main/scala/sttp/tapir/server/tests/CreateServerTest.scala
index 950d018d67..2195898125 100644
--- a/server/tests/src/main/scala/sttp/tapir/server/tests/CreateServerTest.scala
+++ b/server/tests/src/main/scala/sttp/tapir/server/tests/CreateServerTest.scala
@@ -14,6 +14,7 @@ import sttp.tapir._
 import sttp.tapir.server.ServerEndpoint
 import sttp.tapir.server.interceptor.CustomiseInterceptors
 import sttp.tapir.tests._
+import scala.concurrent.duration.FiniteDuration
 
 trait CreateServerTest[F[_], +R, OPTIONS, ROUTE] {
   protected type Interceptors = CustomiseInterceptors[F, OPTIONS] => CustomiseInterceptors[F, OPTIONS]
@@ -32,9 +33,25 @@ trait CreateServerTest[F[_], +R, OPTIONS, ROUTE] {
     runTest: (SttpBackend[IO, Fs2Streams[IO] with WebSockets], Uri) => IO[Assertion]
   ): Test
 
+  def testServerLogicWithStop(
+      e: ServerEndpoint[R, F],
+      testNameSuffix: String = "",
+      interceptors: Interceptors = identity,
+      gracefulShutdownTimeout: Option[FiniteDuration] = None
+  )(
+      runTest: IO[Unit] => (SttpBackend[IO, Fs2Streams[IO] with WebSockets], Uri) => IO[Assertion]
+  ): Test = testServerLogic(e, testNameSuffix, interceptors)(runTest(IO.unit))
+
   def testServer(name: String, rs: => NonEmptyList[ROUTE])(
       runTest: (SttpBackend[IO, Fs2Streams[IO] with WebSockets], Uri) => IO[Assertion]
   ): Test
+
+  /** Override in a server interpreter to run tests which have access to a stop() effect, so that the server can be shut down within the test.
+    * By default, this method just uses a no-op IO.unit.
+ */ + def testServerWithStop(name: String, rs: => NonEmptyList[ROUTE], gracefulShutdownTimeout: Option[FiniteDuration])( + runTest: IO[Unit] => (SttpBackend[IO, Fs2Streams[IO] with WebSockets], Uri) => IO[Assertion] + ): Test = testServer(name, rs)(runTest(IO.unit)) } class DefaultCreateServerTest[F[_], +R, OPTIONS, ROUTE]( @@ -56,6 +74,20 @@ class DefaultCreateServerTest[F[_], +R, OPTIONS, ROUTE]( )(runTest) } + override def testServerLogicWithStop( + e: ServerEndpoint[R, F], + testNameSuffix: String = "", + interceptors: Interceptors = identity, + gracefulShutdownTimeout: Option[FiniteDuration] = None + )( + runTest: IO[Unit] => (SttpBackend[IO, Fs2Streams[IO] with WebSockets], Uri) => IO[Assertion] + ): Test = { + testServerWithStop( + e.showDetail + (if (testNameSuffix == "") "" else " " + testNameSuffix), + NonEmptyList.of(interpreter.route(e, interceptors)), + gracefulShutdownTimeout + )(runTest) + } override def testServerLogic( e: ServerEndpoint[R, F], testNameSuffix: String = "", @@ -69,6 +101,24 @@ class DefaultCreateServerTest[F[_], +R, OPTIONS, ROUTE]( )(runTest) } + override def testServerWithStop(name: String, rs: => NonEmptyList[ROUTE], gracefulShutdownTimeout: Option[FiniteDuration])( + runTest: IO[Unit] => (SttpBackend[IO, Fs2Streams[IO] with WebSockets], Uri) => IO[Assertion] + ): Test = { + val resources = for { + portAndStop <- interpreter.serverWithStop(rs, gracefulShutdownTimeout).onError { case e: Exception => + Resource.eval(IO(logger.error(s"Starting server failed because of ${e.getMessage}"))) + } + _ <- Resource.eval(IO(logger.info(s"Bound server on port: ${portAndStop._1}"))) + } yield portAndStop + + Test(name)( + resources + .use { case (port, stopServer) => + runTest(stopServer)(backend, uri"http://localhost:$port").guarantee(IO(logger.info(s"Tests completed on port $port"))) + } + .unsafeToFuture() + ) + } override def testServer(name: String, rs: => NonEmptyList[ROUTE])( runTest: (SttpBackend[IO, Fs2Streams[IO] with WebSockets], Uri) => IO[Assertion] ): Test = { diff --git a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerGracefulShutdownTests.scala b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerGracefulShutdownTests.scala new file mode 100644 index 0000000000..e950b81b23 --- /dev/null +++ b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerGracefulShutdownTests.scala @@ -0,0 +1,64 @@ +package sttp.tapir.server.tests + +import cats.effect.IO +import cats.syntax.all._ +import org.scalatest.EitherValues +import org.scalatest.matchers.should.Matchers._ +import sttp.client3._ +import sttp.model.StatusCode +import sttp.monad.MonadError +import sttp.monad.syntax._ +import sttp.tapir._ +import sttp.tapir.tests._ + +import scala.concurrent.duration._ + +class ServerGracefulShutdownTests[F[_], OPTIONS, ROUTE](createServerTest: CreateServerTest[F, Any, OPTIONS, ROUTE], sleeper: Sleeper[F])(implicit + m: MonadError[F] +) extends EitherValues { + import createServerTest._ + + def tests(): List[Test] = List( + testServerLogicWithStop( + endpoint + .out(plainBody[String]) + .serverLogic { _ => + sleeper.sleep(3.seconds).flatMap(_ => pureResult("processing finished".asRight[Unit])) + }, + "Server waits for long-running request to complete within timeout", + gracefulShutdownTimeout = Some(4.seconds) + ) { (stopServer) => (backend, baseUri) => + (for { + runningRequest <- basicRequest.get(uri"$baseUri").send(backend).start + _ <- IO.sleep(1.second) + runningStop <- stopServer.start + result <- runningRequest.join.attempt + _ <- 
runningStop.join + } yield { + result.value.isSuccess shouldBe true + }) + }, + testServerLogicWithStop( + endpoint + .out(plainBody[String]) + .serverLogic { _ => + sleeper.sleep(4.seconds).flatMap(_ => pureResult("processing finished".asRight[Unit])) + }, + "Server rejects requests with 503 during shutdown", + gracefulShutdownTimeout = Some(6.seconds) + ) { (stopServer) => (backend, baseUri) => + (for { + runningRequest <- basicRequest.get(uri"$baseUri").send(backend).start + _ <- IO.sleep(1.second) + runningStop <- stopServer.start + _ <- IO.sleep(1.seconds) + rejected <- basicRequest.get(uri"$baseUri").send(backend).attempt + firstResult <- runningRequest.join.attempt + _ <- runningStop.join + } yield { + (rejected.value.code shouldBe StatusCode.ServiceUnavailable): Unit + firstResult.value.isSuccess shouldBe true + }) + } + ) +} diff --git a/server/tests/src/main/scala/sttp/tapir/server/tests/Sleeper.scala b/server/tests/src/main/scala/sttp/tapir/server/tests/Sleeper.scala new file mode 100644 index 0000000000..1e29cefe94 --- /dev/null +++ b/server/tests/src/main/scala/sttp/tapir/server/tests/Sleeper.scala @@ -0,0 +1,20 @@ +package sttp.tapir.server.tests + +import scala.concurrent.duration._ +import scala.concurrent.blocking +import scala.concurrent.Future +import scala.concurrent.ExecutionContext + +trait Sleeper[F[_]] { + def sleep(duration: FiniteDuration): F[Unit] +} + +object Sleeper { + def futureSleeper(implicit ec: ExecutionContext): Sleeper[Future] = new Sleeper[Future] { + override def sleep(duration: FiniteDuration): Future[Unit] = Future { + blocking { + Thread.sleep(duration.toMillis) + } + } + } +} diff --git a/server/tests/src/main/scala/sttp/tapir/server/tests/TestServerInterpreter.scala b/server/tests/src/main/scala/sttp/tapir/server/tests/TestServerInterpreter.scala index 50751cb9cd..9b848b9364 100644 --- a/server/tests/src/main/scala/sttp/tapir/server/tests/TestServerInterpreter.scala +++ b/server/tests/src/main/scala/sttp/tapir/server/tests/TestServerInterpreter.scala @@ -5,6 +5,7 @@ import cats.effect.{IO, Resource} import sttp.tapir.server.ServerEndpoint import sttp.tapir.server.interceptor.CustomiseInterceptors import sttp.tapir.tests.Port +import scala.concurrent.duration.FiniteDuration trait TestServerInterpreter[F[_], +R, OPTIONS, ROUTE] { protected type Interceptors = CustomiseInterceptors[F, OPTIONS] => CustomiseInterceptors[F, OPTIONS] @@ -17,4 +18,9 @@ trait TestServerInterpreter[F[_], +R, OPTIONS, ROUTE] { def server(routes: NonEmptyList[ROUTE]): Resource[IO, Port] + /** Exposes additional `stop` effect, which allows stopping the server inside your test. It will be called after the test anyway (assuming + * idempotency), but may be useful for some cases where tests need to check specific behavior like returning 503s during shutdown. + */ + def serverWithStop(routes: NonEmptyList[ROUTE], gracefulShutdownTimeout: Option[FiniteDuration] = None): Resource[IO, (Port, IO[Unit])] = + server(routes).map(port => (port, IO.unit)) }
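For completeness, here is how the new config option and `stop()` are intended to be used together from application code; a minimal `Future`-based sketch (the endpoint and its logic are hypothetical, everything else follows the API added above):

```scala
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.Future
import scala.concurrent.duration._
import sttp.tapir._
import sttp.tapir.server.netty.{NettyConfig, NettyFutureServer}

// hypothetical endpoint, for illustration only
val hello = endpoint.get.in("hello").out(stringBody).serverLogicSuccess[Future](_ => Future.successful("hello"))

// wait up to 5 seconds for in-flight requests when stopping
val config = NettyConfig.defaultNoStreaming.withGracefulShutdownTimeout(5.seconds)

val program: Future[Unit] = for {
  binding <- NettyFutureServer(config).addEndpoint(hello).start()
  // ... the server is running; later, shut it down:
  _ <- binding.stop() // rejects new requests with 503 while in-flight ones finish
} yield ()
```

The Cats Effect and ZIO bindings expose the same `stop()` on their respective binding types, which is what `ServerGracefulShutdownTests` exercises for each backend.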