Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support graceful shutdown in Netty server #3294

Merged
merged 16 commits into from
Nov 7, 2023
3 changes: 2 additions & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -2044,7 +2044,8 @@ lazy val examples: ProjectMatrix = (projectMatrix in file("examples"))
scalaTest.value
),
libraryDependencies ++= loggerDependencies,
publishArtifact := false
publishArtifact := false,
Compile / run / fork := true
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is that temporary for development or a new requirement?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is useful for development, when one wants to run examples with examples/runMain a.b.c.HelloExample and terminate a running server.

)
.jvmPlatform(scalaVersions = examplesScalaVersions)
.dependsOn(
Expand Down
15 changes: 15 additions & 0 deletions doc/server/netty.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,21 @@ NettyFutureServer(NettyFutureServerOptions.customiseInterceptors.serverLog(None)
NettyFutureServer(NettyConfig.defaultNoStreaming.socketBacklog(256))
```

## Graceful shutdown

A Netty should can be gracefully closed using function `NettyFutureServerBinding.stop()` (and analogous functions available in Cats and ZIO bindings). This function ensures that the server will wait at most 10 seconds for in-flight requests to complete, while rejecting all new requests with 503 during this period. Afterwards, it closes all server resources.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should can? ;) maybe just run it through chatgpt to iron out the english :)

You can customize this behavior in `NettyConfig`:

```scala mdoc:compile-only
import sttp.tapir.server.netty.NettyConfig
import scala.concurrent.duration._

// adjust the waiting time to your needs
val config = NettyConfig.defaultNoStreaming.withGracefulShutdownTimeout(5.seconds)
// or if you don't want the server to wait for in-flight requests
val config2 = NettyConfig.defaultNoStreaming.noGracefulShutdown
```

## Domain socket support

There is possibility to use Domain socket instead of TCP for handling traffic.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
package sttp.tapir.examples

import cats.effect.IO
import cats.effect.IOApp
import sttp.client3.{HttpURLConnectionBackend, Identity, SttpBackend, UriContext, asStringAlways, basicRequest}
import sttp.model.StatusCode
import sttp.tapir.{PublicEndpoint, endpoint, query, stringBody}
import cats.effect.unsafe.implicits.global
import sttp.tapir.server.netty.cats.{NettyCatsServer, NettyCatsServerBinding}
import sttp.tapir.server.netty.cats.NettyCatsServer

object HelloWorldNettyCatsServer extends App {
object HelloWorldNettyCatsServer extends IOApp.Simple {
// One endpoint on GET /hello with query parameter `name`
val helloWorldEndpoint: PublicEndpoint[String, Unit, String, Any] =
endpoint.get.in("hello").in(query[String]("name")).out(stringBody)
Expand All @@ -20,37 +20,37 @@ object HelloWorldNettyCatsServer extends App {
private val declaredHost = "localhost"

// Creating handler for netty bootstrap
NettyCatsServer
override def run = NettyCatsServer
.io()
.use { server =>

val effect: IO[NettyCatsServerBinding[IO]] = server
.port(declaredPort)
.host(declaredHost)
.addEndpoint(helloWorldServerEndpoint)
.start()

effect.map { binding =>

val port = binding.port
val host = binding.hostName
println(s"Server started at port = ${binding.port}")

val backend: SttpBackend[Identity, Any] = HttpURLConnectionBackend()
val badUrl = uri"http://$host:$port/bad_url"
assert(basicRequest.response(asStringAlways).get(badUrl).send(backend).code == StatusCode(404))

val noQueryParameter = uri"http://$host:$port/hello"
assert(basicRequest.response(asStringAlways).get(noQueryParameter).send(backend).code == StatusCode(400))

val allGood = uri"http://$host:$port/hello?name=Netty"
val body = basicRequest.response(asStringAlways).get(allGood).send(backend).body

println("Got result: " + body)
assert(body == "Hello, Netty!")
assert(port == declaredPort, "Ports don't match")
assert(host == declaredHost, "Hosts don't match")
}
for {
binding <- server
.port(declaredPort)
.host(declaredHost)
.addEndpoint(helloWorldServerEndpoint)
.start()
result <- IO
.blocking {
val port = binding.port
val host = binding.hostName
println(s"Server started at port = ${binding.port}")

val backend: SttpBackend[Identity, Any] = HttpURLConnectionBackend()
val badUrl = uri"http://$host:$port/bad_url"
assert(basicRequest.response(asStringAlways).get(badUrl).send(backend).code == StatusCode(404))

val noQueryParameter = uri"http://$host:$port/hello"
assert(basicRequest.response(asStringAlways).get(noQueryParameter).send(backend).code == StatusCode(400))

val allGood = uri"http://$host:$port/hello?name=Netty"
val body = basicRequest.response(asStringAlways).get(allGood).send(backend).body

println("Got result: " + body)
assert(body == "Hello, Netty!")
assert(port == declaredPort, "Ports don't match")
assert(host == declaredHost, "Hosts don't match")
}
.guarantee(binding.stop())
} yield result
}
.unsafeRunSync()
}
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
package sttp.tapir.server.netty.cats

import cats.effect.kernel.Sync
import cats.effect.std.Dispatcher
import cats.effect.{Async, IO, Resource}
import cats.effect.{Async, IO, Resource, Temporal}
import cats.syntax.all._
import io.netty.channel._
import io.netty.channel.group.{ChannelGroup, DefaultChannelGroup}
import io.netty.channel.unix.DomainSocketAddress
import io.netty.util.concurrent.DefaultEventExecutor
import sttp.capabilities.fs2.Fs2Streams
import sttp.monad.MonadError
import sttp.tapir.integ.cats.effect.CatsMonadError
Expand All @@ -17,7 +20,9 @@ import sttp.tapir.server.netty.{NettyConfig, NettyResponse, Route}
import java.net.{InetSocketAddress, SocketAddress}
import java.nio.file.{Path, Paths}
import java.util.UUID
import java.util.concurrent.atomic.AtomicBoolean
import scala.concurrent.Future
import scala.concurrent.duration._

case class NettyCatsServer[F[_]: Async](routes: Vector[Route[F]], options: NettyCatsServerOptions[F], config: NettyConfig) {
def addEndpoint(se: ServerEndpoint[Fs2Streams[F], F]): NettyCatsServer[F] = addEndpoints(List(se))
Expand Down Expand Up @@ -62,26 +67,57 @@ case class NettyCatsServer[F[_]: Async](routes: Vector[Route[F]], options: Netty
val eventLoopGroup = config.eventLoopConfig.initEventLoopGroup()
implicit val monadError: MonadError[F] = new CatsMonadError[F]()
val route: Route[F] = Route.combine(routes)
val channelGroup = new DefaultChannelGroup(new DefaultEventExecutor()) // thread safe
val isShuttingDown: AtomicBoolean = new AtomicBoolean(false)

val channelFuture =
NettyBootstrap(
config,
new NettyServerHandler(route, unsafeRunAsync, config.maxContentLength),
new NettyServerHandler(route, unsafeRunAsync, config.maxContentLength, channelGroup, isShuttingDown),
eventLoopGroup,
socketOverride
)

nettyChannelFutureToScala(channelFuture).map(ch => (ch.localAddress().asInstanceOf[SA], () => stop(ch, eventLoopGroup)))
nettyChannelFutureToScala(channelFuture).map(ch =>
(
ch.localAddress().asInstanceOf[SA],
() => stop(ch, eventLoopGroup, channelGroup, isShuttingDown, config.gracefulShutdownTimeout)
)
)
}

private def stop(ch: Channel, eventLoopGroup: EventLoopGroup): F[Unit] = {
Async[F].defer {
nettyFutureToScala(ch.close()).flatMap { _ =>
if (config.shutdownEventLoopGroupOnClose) {
nettyFutureToScala(eventLoopGroup.shutdownGracefully()).map(_ => ())
} else Async[F].unit
}
private def waitForClosedChannels(
channelGroup: ChannelGroup,
startNanos: Long,
gracefulShutdownTimeoutNanos: Option[Long]
): F[Unit] =
if (!channelGroup.isEmpty && gracefulShutdownTimeoutNanos.exists(_ >= System.nanoTime() - startNanos)) {
Temporal[F].sleep(100.millis) >>
waitForClosedChannels(channelGroup, startNanos, gracefulShutdownTimeoutNanos)
} else {
Sync[F].delay(nettyFutureToScala(channelGroup.close())).void
}

private def stop(
ch: Channel,
eventLoopGroup: EventLoopGroup,
channelGroup: ChannelGroup,
isShuttingDown: AtomicBoolean,
gracefulShutdownTimeout: Option[FiniteDuration]
): F[Unit] = {
Sync[F].delay(isShuttingDown.set(true)) >>
waitForClosedChannels(
channelGroup,
startNanos = System.nanoTime(),
gracefulShutdownTimeoutNanos = gracefulShutdownTimeout.map(_.toNanos)
) >>
Async[F].defer {
nettyFutureToScala(ch.close()).flatMap { _ =>
if (config.shutdownEventLoopGroupOnClose) {
nettyFutureToScala(eventLoopGroup.shutdownGracefully()).map(_ => ())
} else Async[F].unit
}
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import sttp.tapir.server.tests._
import sttp.tapir.tests.{Test, TestSuite}

import scala.concurrent.Future
import scala.concurrent.duration.FiniteDuration

class NettyCatsServerTest extends TestSuite with EitherValues {

Expand All @@ -23,6 +24,9 @@ class NettyCatsServerTest extends TestSuite with EitherValues {

val interpreter = new NettyCatsTestServerInterpreter(eventLoopGroup, dispatcher)
val createServerTest = new DefaultCreateServerTest(backend, interpreter)
val ioSleeper: Sleeper[IO] = new Sleeper[IO] {
override def sleep(duration: FiniteDuration): IO[Unit] = IO.sleep(duration)
}

val tests = new AllServerTests(
createServerTest,
Expand All @@ -34,7 +38,8 @@ class NettyCatsServerTest extends TestSuite with EitherValues {
.tests() ++
new ServerStreamingTests(createServerTest, Fs2Streams[IO]).tests() ++
new ServerCancellationTests(createServerTest)(m, IO.asyncForIO).tests() ++
new NettyFs2StreamingCancellationTest(createServerTest).tests()
new NettyFs2StreamingCancellationTest(createServerTest).tests() ++
new ServerGracefulShutdownTests(createServerTest, ioSleeper).tests()

IO.pure((tests, eventLoopGroup))
} { case (_, eventLoopGroup) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import sttp.tapir.server.netty.{NettyConfig, Route}
import sttp.tapir.server.tests.TestServerInterpreter
import sttp.tapir.tests.Port
import sttp.capabilities.fs2.Fs2Streams
import scala.concurrent.duration.FiniteDuration

class NettyCatsTestServerInterpreter(eventLoopGroup: NioEventLoopGroup, dispatcher: Dispatcher[IO])
extends TestServerInterpreter[IO, Fs2Streams[IO], NettyCatsServerOptions[IO], Route[IO]] {
Expand All @@ -19,18 +20,28 @@ class NettyCatsTestServerInterpreter(eventLoopGroup: NioEventLoopGroup, dispatch
NettyCatsServerInterpreter(serverOptions).toRoute(es)
}

override def server(routes: NonEmptyList[Route[IO]]): Resource[IO, Port] = {
override def serverWithStop(
routes: NonEmptyList[Route[IO]],
gracefulShutdownTimeout: Option[FiniteDuration] = None
): Resource[IO, (Port, IO[Unit])] = {
val config = NettyConfig.defaultWithStreaming
.eventLoopGroup(eventLoopGroup)
.randomPort
.withDontShutdownEventLoopGroupOnClose
.maxContentLength(NettyCatsTestServerInterpreter.maxContentLength)
.noGracefulShutdown

val customizedConfig = gracefulShutdownTimeout.map(config.withGracefulShutdownTimeout).getOrElse(config)
val options = NettyCatsServerOptions.default[IO](dispatcher)
val bind: IO[NettyCatsServerBinding[IO]] = NettyCatsServer(options, config).addRoutes(routes.toList).start()
val bind: IO[NettyCatsServerBinding[IO]] = NettyCatsServer(options, customizedConfig).addRoutes(routes.toList).start()

Resource
.make(bind)(_.stop())
.map(_.port)
.make(bind.map(b => (b, b.stop()))) { case (_, stop) => stop }
.map { case (b, stop) => (b.port, stop) }
}

override def server(routes: NonEmptyList[Route[IO]]): Resource[IO, Port] = {
serverWithStop(routes).map(_._1)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ import scala.concurrent.duration._
* @param lingerTimeout
* Sets the delay for which the Netty waits, while data is being transmitted, before closing a socket after receiving a call to close the
* socket
*
* @param gracefulShutdownTimeout
* If set, attempts to wait for a given time for all in-flight requests to complete, before proceeding with shutting down the server. If
* `None`, closes the channels and terminates the server without waiting.
*/
case class NettyConfig(
host: String,
Expand All @@ -64,7 +68,8 @@ case class NettyConfig(
sslContext: Option[SslContext],
eventLoopConfig: EventLoopConfig,
socketConfig: NettySocketConfig,
initPipeline: NettyConfig => (ChannelPipeline, ChannelHandler) => Unit
initPipeline: NettyConfig => (ChannelPipeline, ChannelHandler) => Unit,
gracefulShutdownTimeout: Option[FiniteDuration]
) {
def host(h: String): NettyConfig = copy(host = h)

Expand Down Expand Up @@ -102,6 +107,9 @@ case class NettyConfig(
def eventLoopGroup(elg: EventLoopGroup): NettyConfig = copy(eventLoopConfig = EventLoopConfig.useExisting(elg))

def initPipeline(f: NettyConfig => (ChannelPipeline, ChannelHandler) => Unit): NettyConfig = copy(initPipeline = f)

def withGracefulShutdownTimeout(t: FiniteDuration) = copy(gracefulShutdownTimeout = Some(t))
def noGracefulShutdown = copy(gracefulShutdownTimeout = None)
}

object NettyConfig {
Expand All @@ -115,6 +123,7 @@ object NettyConfig {
connectionTimeout = Some(10.seconds),
socketTimeout = Some(60.seconds),
lingerTimeout = Some(60.seconds),
gracefulShutdownTimeout = Some(10.seconds),
maxContentLength = None,
maxConnections = None,
addLoggingHandler = false,
Expand Down
Loading
Loading