Skip to content

Commit

Permalink
Merge pull request #3256 from softwaremill/netty-server-side-cancel
Browse files Browse the repository at this point in the history
Netty server side cancellation
  • Loading branch information
adamw authored Oct 31, 2023
2 parents 7f4945e + bd8b94a commit c969f0f
Show file tree
Hide file tree
Showing 8 changed files with 253 additions and 43 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,19 @@ import cats.effect.{Async, IO, Resource}
import cats.syntax.all._
import io.netty.channel._
import io.netty.channel.unix.DomainSocketAddress
import sttp.capabilities.fs2.Fs2Streams
import sttp.monad.MonadError
import sttp.tapir.integ.cats.effect.CatsMonadError
import sttp.tapir.server.ServerEndpoint
import sttp.tapir.server.model.ServerResponse
import sttp.tapir.server.netty.cats.internal.CatsUtil.{nettyChannelFutureToScala, nettyFutureToScala}
import sttp.tapir.server.netty.internal.{NettyBootstrap, NettyServerHandler}
import sttp.tapir.server.netty.{NettyConfig, Route}
import sttp.tapir.server.netty.{NettyConfig, NettyResponse, Route}

import java.net.{InetSocketAddress, SocketAddress}
import java.nio.file.{Path, Paths}
import java.util.UUID
import sttp.capabilities.fs2.Fs2Streams
import scala.concurrent.Future

case class NettyCatsServer[F[_]: Async](routes: Vector[Route[F]], options: NettyCatsServerOptions[F], config: NettyConfig) {
def addEndpoint(se: ServerEndpoint[Fs2Streams[F], F]): NettyCatsServer[F] = addEndpoints(List(se))
Expand Down Expand Up @@ -53,6 +55,9 @@ case class NettyCatsServer[F[_]: Async](routes: Vector[Route[F]], options: Netty
NettyCatsDomainSocketBinding(socket, stop)
}

private def unsafeRunAsync(block: () => F[ServerResponse[NettyResponse]]): (Future[ServerResponse[NettyResponse]], () => Future[Unit]) =
options.dispatcher.unsafeToFutureCancelable(block())

private def startUsingSocketOverride[SA <: SocketAddress](socketOverride: Option[SA]): F[(SA, () => F[Unit])] = {
val eventLoopGroup = config.eventLoopConfig.initEventLoopGroup()
implicit val monadError: MonadError[F] = new CatsMonadError[F]()
Expand All @@ -61,7 +66,7 @@ case class NettyCatsServer[F[_]: Async](routes: Vector[Route[F]], options: Netty
val channelFuture =
NettyBootstrap(
config,
new NettyServerHandler(route, (f: () => F[Unit]) => options.dispatcher.unsafeToFuture(f()), config.maxContentLength),
new NettyServerHandler(route, unsafeRunAsync, config.maxContentLength),
eventLoopGroup,
socketOverride
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,10 @@ class NettyCatsServerTest extends TestSuite with EitherValues {
multipart = false,
maxContentLength = Some(NettyCatsTestServerInterpreter.maxContentLength)
)
.tests() ++ new ServerStreamingTests(createServerTest, Fs2Streams[IO]).tests()
.tests() ++
new ServerStreamingTests(createServerTest, Fs2Streams[IO]).tests() ++
new ServerCancellationTests(createServerTest)(m, IO.asyncForIO).tests() ++
new NettyFs2StreamingCancellationTest(createServerTest).tests()

IO.pure((tests, eventLoopGroup))
} { case (_, eventLoopGroup) =>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package sttp.tapir.server.netty.cats

import cats.effect.IO
import cats.syntax.all._
import org.scalatest.matchers.should.Matchers._
import sttp.capabilities.fs2.Fs2Streams
import sttp.client3._
import sttp.monad.MonadError
import sttp.tapir.integ.cats.effect.CatsMonadError
import sttp.tapir.server.tests.{CreateServerTest, _}
import sttp.tapir.tests._
import sttp.tapir.{CodecFormat, _}

import java.nio.charset.StandardCharsets
import scala.concurrent.duration._
import cats.effect.std.Queue
import cats.effect.unsafe.implicits.global

class NettyFs2StreamingCancellationTest[OPTIONS, ROUTE](createServerTest: CreateServerTest[IO, Fs2Streams[IO], OPTIONS, ROUTE]) {
import createServerTest._

implicit val m: MonadError[IO] = new CatsMonadError[IO]()
def tests(): List[Test] = List({
val buffer = Queue.unbounded[IO, Byte].unsafeRunSync()
val body_20_slowly_emitted_bytes =
fs2.Stream.awakeEvery[IO](100.milliseconds).map(_ => 42.toByte).evalMap(b => { buffer.offer(b) >> IO.pure(b) }).take(100)
testServer(
endpoint.get
.in("streamCanceled")
.out(streamTextBody(Fs2Streams[IO])(CodecFormat.TextPlain(), Some(StandardCharsets.UTF_8))),
"Client cancelling streaming triggers cancellation on the server"
)(_ => pureResult(body_20_slowly_emitted_bytes.asRight[Unit])) { (backend, baseUri) =>

val expectedMaxAccumulated = 3

basicRequest
.get(uri"$baseUri/streamCanceled")
.send(backend)
.timeout(300.millis)
.attempt >>
IO.sleep(600.millis)
.flatMap(_ =>
buffer.size.flatMap(accumulated =>
IO(
assert(
accumulated <= expectedMaxAccumulated,
s"Buffer accumulated $accumulated elements. Expected < $expectedMaxAccumulated due to cancellation."
)
)
)
)
}
})
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import java.net.{InetSocketAddress, SocketAddress}
import java.nio.file.{Path, Paths}
import java.util.UUID
import scala.concurrent.{ExecutionContext, Future}
import sttp.tapir.server.model.ServerResponse

case class NettyFutureServer(routes: Vector[FutureRoute], options: NettyFutureServerOptions, config: NettyConfig)(implicit
ec: ExecutionContext
Expand Down Expand Up @@ -49,6 +50,12 @@ case class NettyFutureServer(routes: Vector[FutureRoute], options: NettyFutureSe
NettyFutureDomainSocketBinding(socket, stop)
}

private def unsafeRunAsync(
block: () => Future[ServerResponse[NettyResponse]]
): (Future[ServerResponse[NettyResponse]], () => Future[Unit]) = {
(block(), () => Future.unit) // noop cancellation handler, we can't cancel native Futures
}

private def startUsingSocketOverride[SA <: SocketAddress](socketOverride: Option[SA]): Future[(SA, () => Future[Unit])] = {
val eventLoopGroup = config.eventLoopConfig.initEventLoopGroup()
implicit val monadError: MonadError[Future] = new FutureMonad()
Expand All @@ -57,7 +64,7 @@ case class NettyFutureServer(routes: Vector[FutureRoute], options: NettyFutureSe
val channelFuture =
NettyBootstrap(
config,
new NettyServerHandler(route, (f: () => Future[Unit]) => f(), config.maxContentLength),
new NettyServerHandler(route, unsafeRunAsync, config.maxContentLength),
eventLoopGroup,
socketOverride
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,46 @@ import sttp.tapir.server.netty.NettyResponseContent.{
import sttp.tapir.server.netty.{NettyResponse, NettyServerRequest, Route}

import scala.collection.JavaConverters._
import scala.collection.mutable.{Queue => MutableQueue}
import scala.concurrent.Future
import scala.util.Failure
import scala.util.Success
import scala.util.control.NonFatal
import scala.concurrent.ExecutionContext

class NettyServerHandler[F[_]](route: Route[F], unsafeRunAsync: (() => F[Unit]) => Unit, maxContentLength: Option[Int])(implicit
/** @param unsafeRunAsync
* Function which dispatches given effect to run asynchronously, returning its result as a Future, and function of type `() =>
* Future[Unit]` allowing cancellation of that Future. For example, this can be realized by
* `cats.effect.std.Dispatcher.unsafeToFutureCancelable`.
*/
class NettyServerHandler[F[_]](
route: Route[F],
unsafeRunAsync: (() => F[ServerResponse[NettyResponse]]) => (Future[ServerResponse[NettyResponse]], () => Future[Unit]),
maxContentLength: Option[Int]
)(implicit
me: MonadError[F]
) extends SimpleChannelInboundHandler[HttpRequest] {

// Cancellation handling with eventLoopContext, lastResponseSent, and pendingResponses has been adapted
// from http4s: https://github.com/http4s/http4s-netty/pull/396/files
// By using the Netty event loop assigned to this channel we get two benefits:
// 1. We can avoid the necessary hopping around of threads since Netty pipelines will
// only pass events up and down from within the event loop to which it is assigned.
// That means calls to ctx.read(), and ct.write(..), would have to be trampolined otherwise.
// 2. We get serialization of execution: the EventLoop is a serial execution queue so
// we can rest easy knowing that no two events will be executed in parallel.
private[this] var eventLoopContext: ExecutionContext = _

// This is used essentially as a queue, each incoming request attaches callbacks to this
// and replaces it to ensure that responses are written out in the same order that they came
// in.
private[this] var lastResponseSent: Future[Unit] = Future.unit

// We keep track of the cancellation tokens for all the requests in flight. This gives us
// observability into the number of requests in flight and the ability to cancel them all
// if the connection gets closed.
private[this] val pendingResponses = MutableQueue.empty[() => Future[Unit]]

private val logger = Logger[NettyServerHandler[F]]

private val EntityTooLarge: FullHttpResponse = {
Expand All @@ -40,33 +75,73 @@ class NettyServerHandler[F[_]](route: Route[F], unsafeRunAsync: (() => F[Unit])
res
}

override def handlerAdded(ctx: ChannelHandlerContext): Unit =
if (ctx.channel.isActive) {
initHandler(ctx)
}
override def channelActive(ctx: ChannelHandlerContext): Unit = initHandler(ctx)

private[this] def initHandler(ctx: ChannelHandlerContext): Unit = {
if (eventLoopContext == null) {
// Initialize our ExecutionContext
eventLoopContext = ExecutionContext.fromExecutor(ctx.channel.eventLoop)

// When the channel closes we want to cancel any pending dispatches.
// Since the listener will be executed from the channels EventLoop everything is thread safe.
val _ = ctx.channel.closeFuture.addListener { (_: ChannelFuture) =>
logger.debug(s"Http channel to ${ctx.channel.remoteAddress} closed. Cancelling ${pendingResponses.length} responses.")
pendingResponses.foreach(_.apply())
}
}
}

override def channelRead0(ctx: ChannelHandlerContext, request: HttpRequest): Unit = {

def runRoute(req: HttpRequest) = {
def writeError500(req: HttpRequest, reason: Throwable): Unit = {
logger.error("Error while processing the request", reason)
// send 500
val res = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.INTERNAL_SERVER_ERROR)
res.headers().set(HttpHeaderNames.CONTENT_LENGTH, 0)
res.handleCloseAndKeepAliveHeaders(req)

route(NettyServerRequest(req))
.map {
case Some(response) => response
case None => ServerResponse.notFound
}
.flatMap((serverResponse: ServerResponse[NettyResponse]) =>
// in ZIO, exceptions thrown in .map become defects - instead, we want them represented as errors so that
// we get the 500 response, instead of dropping the request
try handleResponse(ctx, req, serverResponse).unit
catch {
case e: Exception => me.error[Unit](e)
}
)
.handleError { case ex: Exception =>
logger.error("Error while processing the request", ex)
// send 500
val res = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.INTERNAL_SERVER_ERROR)
res.headers().set(HttpHeaderNames.CONTENT_LENGTH, 0)
res.handleCloseAndKeepAliveHeaders(req)
ctx.writeAndFlush(res).closeIfNeeded(req)

ctx.writeAndFlush(res).closeIfNeeded(req)
me.unit(())
}
}

def runRoute(req: HttpRequest, releaseReq: () => Any = () => ()): Unit = {
val (runningFuture, cancellationSwitch) = unsafeRunAsync { () =>
route(NettyServerRequest(req))
.map {
case Some(response) => response
case None => ServerResponse.notFound
}
}
pendingResponses.enqueue(cancellationSwitch)
lastResponseSent = lastResponseSent.flatMap { _ =>
runningFuture.transform {
case Success(serverResponse) =>
pendingResponses.dequeue()
try {
handleResponse(ctx, req, serverResponse)
Success(())
} catch {
case NonFatal(ex) =>
writeError500(req, ex)
Failure(ex)
} finally {
val _ = releaseReq()
}
case Failure(NonFatal(ex)) =>
try {
writeError500(req, ex)
Failure(ex)
}
finally {
val _ = releaseReq()
}
case Failure(fatalException) => Failure(fatalException)
}(eventLoopContext)
}(eventLoopContext)
}

if (HttpUtil.is100ContinueExpected(request)) {
Expand All @@ -76,14 +151,9 @@ class NettyServerHandler[F[_]](route: Route[F], unsafeRunAsync: (() => F[Unit])
request match {
case full: FullHttpRequest =>
val req = full.retain()
unsafeRunAsync { () =>
runRoute(req)
.ensure(me.eval(req.release()))
} // exceptions should be handled
runRoute(req, () => req.release())
case req: StreamedHttpRequest =>
unsafeRunAsync { () =>
runRoute(req)
}
runRoute(req)
case _ => throw new UnsupportedOperationException(s"Unexpected Netty request type: ${request.getClass.getName}")
}

Expand Down Expand Up @@ -156,22 +226,22 @@ class NettyServerHandler[F[_]](route: Route[F], unsafeRunAsync: (() => F[Unit])

if (!HttpUtil.is100ContinueExpected(req) && !HttpUtil.isKeepAlive(req)) {
val future: ChannelFuture = ctx.writeAndFlush(EntityTooLargeClose.retainedDuplicate())
future.addListener(new ChannelFutureListener() {
val _ = future.addListener(new ChannelFutureListener() {
override def operationComplete(future: ChannelFuture) = {
if (!future.isSuccess()) {
logger.warn("Failed to send a 413 Request Entity Too Large.", future.cause())
}
ctx.close()
val _ = ctx.close()
}
})
} else {
ctx
val _ = ctx
.writeAndFlush(EntityTooLarge.retainedDuplicate())
.addListener(new ChannelFutureListener() {
override def operationComplete(future: ChannelFuture) = {
if (!future.isSuccess()) {
logger.warn("Failed to send a 413 Request Entity Too Large.", future.cause())
ctx.close()
val _ = ctx.close()
}
}
})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,18 @@ import io.netty.channel._
import io.netty.channel.unix.DomainSocketAddress
import sttp.capabilities.zio.ZioStreams
import sttp.tapir.server.ServerEndpoint
import sttp.tapir.server.netty.{NettyConfig, Route}
import sttp.tapir.server.model.ServerResponse
import sttp.tapir.server.netty.internal.{NettyBootstrap, NettyServerHandler}
import sttp.tapir.server.netty.zio.internal.ZioUtil.{nettyChannelFutureToScala, nettyFutureToScala}
import sttp.tapir.server.netty.{NettyConfig, NettyResponse, Route}
import sttp.tapir.ztapir.{RIOMonadError, ZServerEndpoint}
import zio.{RIO, Unsafe, ZIO}

import java.net.{InetSocketAddress, SocketAddress}
import java.nio.file.{Path, Paths}
import java.util.UUID
import scala.concurrent.ExecutionContext.Implicits
import scala.concurrent.Future

case class NettyZioServer[R](routes: Vector[RIO[R, Route[RIO[R, *]]]], options: NettyZioServerOptions[R], config: NettyConfig) {
def addEndpoint(se: ZServerEndpoint[R, ZioStreams]): NettyZioServer[R] = addEndpoints(List(se))
Expand Down Expand Up @@ -55,6 +58,17 @@ case class NettyZioServer[R](routes: Vector[RIO[R, Route[RIO[R, *]]]], options:
NettyZioDomainSocketBinding(socket, stop)
}

private def unsafeRunAsync(
runtime: zio.Runtime[R]
)(block: () => RIO[R, ServerResponse[NettyResponse]]): (Future[ServerResponse[NettyResponse]], () => Future[Unit]) = {
val cancelable = Unsafe.unsafe(implicit u =>
runtime.unsafe.runToFuture(
block()
)
)
(cancelable, () => cancelable.cancel().map(_ => ())(Implicits.global))
}

private def startUsingSocketOverride[SA <: SocketAddress](socketOverride: Option[SA]): RIO[R, (SA, () => RIO[R, Unit])] = for {
runtime <- ZIO.runtime[R]
routes <- ZIO.foreach(routes)(identity)
Expand All @@ -67,7 +81,7 @@ case class NettyZioServer[R](routes: Vector[RIO[R, Route[RIO[R, *]]]], options:
config,
new NettyServerHandler[RIO[R, *]](
route,
(f: () => RIO[R, Unit]) => Unsafe.unsafe(implicit u => runtime.unsafe.runToFuture(f())),
unsafeRunAsync(runtime),
config.maxContentLength
),
eventLoopGroup,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import sttp.tapir.server.tests._
import sttp.tapir.tests.{Test, TestSuite}
import sttp.tapir.ztapir.RIOMonadError
import zio.Task
import zio.interop.catz._

import scala.concurrent.Future

Expand All @@ -26,7 +27,8 @@ class NettyZioServerTest extends TestSuite with EitherValues {

val tests =
new AllServerTests(createServerTest, interpreter, backend, staticContent = false, multipart = false).tests() ++
new ServerStreamingTests(createServerTest, ZioStreams).tests()
new ServerStreamingTests(createServerTest, ZioStreams).tests() ++
new ServerCancellationTests(createServerTest)(monadError, asyncInstance).tests()

IO.pure((tests, eventLoopGroup))
} { case (_, eventLoopGroup) =>
Expand Down
Loading

0 comments on commit c969f0f

Please sign in to comment.