Skip to content

Commit

Permalink
Support graceful shutdown in Netty server (#3294)
Browse files Browse the repository at this point in the history
  • Loading branch information
kciesielski authored Nov 7, 2023
1 parent fc1a336 commit cee01ca
Show file tree
Hide file tree
Showing 18 changed files with 445 additions and 103 deletions.
3 changes: 2 additions & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -2044,7 +2044,8 @@ lazy val examples: ProjectMatrix = (projectMatrix in file("examples"))
scalaTest.value
),
libraryDependencies ++= loggerDependencies,
publishArtifact := false
publishArtifact := false,
Compile / run / fork := true
)
.jvmPlatform(scalaVersions = examplesScalaVersions)
.dependsOn(
Expand Down
15 changes: 15 additions & 0 deletions doc/server/netty.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,21 @@ NettyFutureServer(NettyFutureServerOptions.customiseInterceptors.serverLog(None)
NettyFutureServer(NettyConfig.defaultNoStreaming.socketBacklog(256))
```

## Graceful shutdown

A Netty should can be gracefully closed using function `NettyFutureServerBinding.stop()` (and analogous functions available in Cats and ZIO bindings). This function ensures that the server will wait at most 10 seconds for in-flight requests to complete, while rejecting all new requests with 503 during this period. Afterwards, it closes all server resources.
You can customize this behavior in `NettyConfig`:

```scala mdoc:compile-only
import sttp.tapir.server.netty.NettyConfig
import scala.concurrent.duration._

// adjust the waiting time to your needs
val config = NettyConfig.defaultNoStreaming.withGracefulShutdownTimeout(5.seconds)
// or if you don't want the server to wait for in-flight requests
val config2 = NettyConfig.defaultNoStreaming.noGracefulShutdown
```

## Domain socket support

There is possibility to use Domain socket instead of TCP for handling traffic.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
package sttp.tapir.examples

import cats.effect.IO
import cats.effect.IOApp
import sttp.client3.{HttpURLConnectionBackend, Identity, SttpBackend, UriContext, asStringAlways, basicRequest}
import sttp.model.StatusCode
import sttp.tapir.{PublicEndpoint, endpoint, query, stringBody}
import cats.effect.unsafe.implicits.global
import sttp.tapir.server.netty.cats.{NettyCatsServer, NettyCatsServerBinding}
import sttp.tapir.server.netty.cats.NettyCatsServer

object HelloWorldNettyCatsServer extends App {
object HelloWorldNettyCatsServer extends IOApp.Simple {
// One endpoint on GET /hello with query parameter `name`
val helloWorldEndpoint: PublicEndpoint[String, Unit, String, Any] =
endpoint.get.in("hello").in(query[String]("name")).out(stringBody)
Expand All @@ -20,37 +20,37 @@ object HelloWorldNettyCatsServer extends App {
private val declaredHost = "localhost"

// Creating handler for netty bootstrap
NettyCatsServer
override def run = NettyCatsServer
.io()
.use { server =>

val effect: IO[NettyCatsServerBinding[IO]] = server
.port(declaredPort)
.host(declaredHost)
.addEndpoint(helloWorldServerEndpoint)
.start()

effect.map { binding =>

val port = binding.port
val host = binding.hostName
println(s"Server started at port = ${binding.port}")

val backend: SttpBackend[Identity, Any] = HttpURLConnectionBackend()
val badUrl = uri"http://$host:$port/bad_url"
assert(basicRequest.response(asStringAlways).get(badUrl).send(backend).code == StatusCode(404))

val noQueryParameter = uri"http://$host:$port/hello"
assert(basicRequest.response(asStringAlways).get(noQueryParameter).send(backend).code == StatusCode(400))

val allGood = uri"http://$host:$port/hello?name=Netty"
val body = basicRequest.response(asStringAlways).get(allGood).send(backend).body

println("Got result: " + body)
assert(body == "Hello, Netty!")
assert(port == declaredPort, "Ports don't match")
assert(host == declaredHost, "Hosts don't match")
}
for {
binding <- server
.port(declaredPort)
.host(declaredHost)
.addEndpoint(helloWorldServerEndpoint)
.start()
result <- IO
.blocking {
val port = binding.port
val host = binding.hostName
println(s"Server started at port = ${binding.port}")

val backend: SttpBackend[Identity, Any] = HttpURLConnectionBackend()
val badUrl = uri"http://$host:$port/bad_url"
assert(basicRequest.response(asStringAlways).get(badUrl).send(backend).code == StatusCode(404))

val noQueryParameter = uri"http://$host:$port/hello"
assert(basicRequest.response(asStringAlways).get(noQueryParameter).send(backend).code == StatusCode(400))

val allGood = uri"http://$host:$port/hello?name=Netty"
val body = basicRequest.response(asStringAlways).get(allGood).send(backend).body

println("Got result: " + body)
assert(body == "Hello, Netty!")
assert(port == declaredPort, "Ports don't match")
assert(host == declaredHost, "Hosts don't match")
}
.guarantee(binding.stop())
} yield result
}
.unsafeRunSync()
}
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
package sttp.tapir.server.netty.cats

import cats.effect.kernel.Sync
import cats.effect.std.Dispatcher
import cats.effect.{Async, IO, Resource}
import cats.effect.{Async, IO, Resource, Temporal}
import cats.syntax.all._
import io.netty.channel._
import io.netty.channel.group.{ChannelGroup, DefaultChannelGroup}
import io.netty.channel.unix.DomainSocketAddress
import io.netty.util.concurrent.DefaultEventExecutor
import sttp.capabilities.fs2.Fs2Streams
import sttp.monad.MonadError
import sttp.tapir.integ.cats.effect.CatsMonadError
Expand All @@ -17,7 +20,9 @@ import sttp.tapir.server.netty.{NettyConfig, NettyResponse, Route}
import java.net.{InetSocketAddress, SocketAddress}
import java.nio.file.{Path, Paths}
import java.util.UUID
import java.util.concurrent.atomic.AtomicBoolean
import scala.concurrent.Future
import scala.concurrent.duration._

case class NettyCatsServer[F[_]: Async](routes: Vector[Route[F]], options: NettyCatsServerOptions[F], config: NettyConfig) {
def addEndpoint(se: ServerEndpoint[Fs2Streams[F], F]): NettyCatsServer[F] = addEndpoints(List(se))
Expand Down Expand Up @@ -62,26 +67,57 @@ case class NettyCatsServer[F[_]: Async](routes: Vector[Route[F]], options: Netty
val eventLoopGroup = config.eventLoopConfig.initEventLoopGroup()
implicit val monadError: MonadError[F] = new CatsMonadError[F]()
val route: Route[F] = Route.combine(routes)
val channelGroup = new DefaultChannelGroup(new DefaultEventExecutor()) // thread safe
val isShuttingDown: AtomicBoolean = new AtomicBoolean(false)

val channelFuture =
NettyBootstrap(
config,
new NettyServerHandler(route, unsafeRunAsync, config.maxContentLength),
new NettyServerHandler(route, unsafeRunAsync, config.maxContentLength, channelGroup, isShuttingDown),
eventLoopGroup,
socketOverride
)

nettyChannelFutureToScala(channelFuture).map(ch => (ch.localAddress().asInstanceOf[SA], () => stop(ch, eventLoopGroup)))
nettyChannelFutureToScala(channelFuture).map(ch =>
(
ch.localAddress().asInstanceOf[SA],
() => stop(ch, eventLoopGroup, channelGroup, isShuttingDown, config.gracefulShutdownTimeout)
)
)
}

private def stop(ch: Channel, eventLoopGroup: EventLoopGroup): F[Unit] = {
Async[F].defer {
nettyFutureToScala(ch.close()).flatMap { _ =>
if (config.shutdownEventLoopGroupOnClose) {
nettyFutureToScala(eventLoopGroup.shutdownGracefully()).map(_ => ())
} else Async[F].unit
}
private def waitForClosedChannels(
channelGroup: ChannelGroup,
startNanos: Long,
gracefulShutdownTimeoutNanos: Option[Long]
): F[Unit] =
if (!channelGroup.isEmpty && gracefulShutdownTimeoutNanos.exists(_ >= System.nanoTime() - startNanos)) {
Temporal[F].sleep(100.millis) >>
waitForClosedChannels(channelGroup, startNanos, gracefulShutdownTimeoutNanos)
} else {
Sync[F].delay(nettyFutureToScala(channelGroup.close())).void
}

private def stop(
ch: Channel,
eventLoopGroup: EventLoopGroup,
channelGroup: ChannelGroup,
isShuttingDown: AtomicBoolean,
gracefulShutdownTimeout: Option[FiniteDuration]
): F[Unit] = {
Sync[F].delay(isShuttingDown.set(true)) >>
waitForClosedChannels(
channelGroup,
startNanos = System.nanoTime(),
gracefulShutdownTimeoutNanos = gracefulShutdownTimeout.map(_.toNanos)
) >>
Async[F].defer {
nettyFutureToScala(ch.close()).flatMap { _ =>
if (config.shutdownEventLoopGroupOnClose) {
nettyFutureToScala(eventLoopGroup.shutdownGracefully()).map(_ => ())
} else Async[F].unit
}
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import sttp.tapir.server.tests._
import sttp.tapir.tests.{Test, TestSuite}

import scala.concurrent.Future
import scala.concurrent.duration.FiniteDuration

class NettyCatsServerTest extends TestSuite with EitherValues {

Expand All @@ -23,6 +24,9 @@ class NettyCatsServerTest extends TestSuite with EitherValues {

val interpreter = new NettyCatsTestServerInterpreter(eventLoopGroup, dispatcher)
val createServerTest = new DefaultCreateServerTest(backend, interpreter)
val ioSleeper: Sleeper[IO] = new Sleeper[IO] {
override def sleep(duration: FiniteDuration): IO[Unit] = IO.sleep(duration)
}

val tests = new AllServerTests(
createServerTest,
Expand All @@ -34,7 +38,8 @@ class NettyCatsServerTest extends TestSuite with EitherValues {
.tests() ++
new ServerStreamingTests(createServerTest, Fs2Streams[IO]).tests() ++
new ServerCancellationTests(createServerTest)(m, IO.asyncForIO).tests() ++
new NettyFs2StreamingCancellationTest(createServerTest).tests()
new NettyFs2StreamingCancellationTest(createServerTest).tests() ++
new ServerGracefulShutdownTests(createServerTest, ioSleeper).tests()

IO.pure((tests, eventLoopGroup))
} { case (_, eventLoopGroup) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import sttp.tapir.server.netty.{NettyConfig, Route}
import sttp.tapir.server.tests.TestServerInterpreter
import sttp.tapir.tests.Port
import sttp.capabilities.fs2.Fs2Streams
import scala.concurrent.duration.FiniteDuration

class NettyCatsTestServerInterpreter(eventLoopGroup: NioEventLoopGroup, dispatcher: Dispatcher[IO])
extends TestServerInterpreter[IO, Fs2Streams[IO], NettyCatsServerOptions[IO], Route[IO]] {
Expand All @@ -19,18 +20,28 @@ class NettyCatsTestServerInterpreter(eventLoopGroup: NioEventLoopGroup, dispatch
NettyCatsServerInterpreter(serverOptions).toRoute(es)
}

override def server(routes: NonEmptyList[Route[IO]]): Resource[IO, Port] = {
override def serverWithStop(
routes: NonEmptyList[Route[IO]],
gracefulShutdownTimeout: Option[FiniteDuration] = None
): Resource[IO, (Port, IO[Unit])] = {
val config = NettyConfig.defaultWithStreaming
.eventLoopGroup(eventLoopGroup)
.randomPort
.withDontShutdownEventLoopGroupOnClose
.maxContentLength(NettyCatsTestServerInterpreter.maxContentLength)
.noGracefulShutdown

val customizedConfig = gracefulShutdownTimeout.map(config.withGracefulShutdownTimeout).getOrElse(config)
val options = NettyCatsServerOptions.default[IO](dispatcher)
val bind: IO[NettyCatsServerBinding[IO]] = NettyCatsServer(options, config).addRoutes(routes.toList).start()
val bind: IO[NettyCatsServerBinding[IO]] = NettyCatsServer(options, customizedConfig).addRoutes(routes.toList).start()

Resource
.make(bind)(_.stop())
.map(_.port)
.make(bind.map(b => (b, b.stop()))) { case (_, stop) => stop }
.map { case (b, stop) => (b.port, stop) }
}

override def server(routes: NonEmptyList[Route[IO]]): Resource[IO, Port] = {
serverWithStop(routes).map(_._1)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ import scala.concurrent.duration._
* @param lingerTimeout
* Sets the delay for which the Netty waits, while data is being transmitted, before closing a socket after receiving a call to close the
* socket
*
* @param gracefulShutdownTimeout
* If set, attempts to wait for a given time for all in-flight requests to complete, before proceeding with shutting down the server. If
* `None`, closes the channels and terminates the server without waiting.
*/
case class NettyConfig(
host: String,
Expand All @@ -64,7 +68,8 @@ case class NettyConfig(
sslContext: Option[SslContext],
eventLoopConfig: EventLoopConfig,
socketConfig: NettySocketConfig,
initPipeline: NettyConfig => (ChannelPipeline, ChannelHandler) => Unit
initPipeline: NettyConfig => (ChannelPipeline, ChannelHandler) => Unit,
gracefulShutdownTimeout: Option[FiniteDuration]
) {
def host(h: String): NettyConfig = copy(host = h)

Expand Down Expand Up @@ -102,6 +107,9 @@ case class NettyConfig(
def eventLoopGroup(elg: EventLoopGroup): NettyConfig = copy(eventLoopConfig = EventLoopConfig.useExisting(elg))

def initPipeline(f: NettyConfig => (ChannelPipeline, ChannelHandler) => Unit): NettyConfig = copy(initPipeline = f)

def withGracefulShutdownTimeout(t: FiniteDuration) = copy(gracefulShutdownTimeout = Some(t))
def noGracefulShutdown = copy(gracefulShutdownTimeout = None)
}

object NettyConfig {
Expand All @@ -115,6 +123,7 @@ object NettyConfig {
connectionTimeout = Some(10.seconds),
socketTimeout = Some(60.seconds),
lingerTimeout = Some(60.seconds),
gracefulShutdownTimeout = Some(10.seconds),
maxContentLength = None,
maxConnections = None,
addLoggingHandler = false,
Expand Down
Loading

0 comments on commit cee01ca

Please sign in to comment.