Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Netty server side cancellation #3256

Merged
merged 19 commits into from
Oct 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,19 @@ import cats.effect.{Async, IO, Resource}
import cats.syntax.all._
import io.netty.channel._
import io.netty.channel.unix.DomainSocketAddress
import sttp.capabilities.fs2.Fs2Streams
import sttp.monad.MonadError
import sttp.tapir.integ.cats.effect.CatsMonadError
import sttp.tapir.server.ServerEndpoint
import sttp.tapir.server.model.ServerResponse
import sttp.tapir.server.netty.cats.internal.CatsUtil.{nettyChannelFutureToScala, nettyFutureToScala}
import sttp.tapir.server.netty.internal.{NettyBootstrap, NettyServerHandler}
import sttp.tapir.server.netty.{NettyConfig, Route}
import sttp.tapir.server.netty.{NettyConfig, NettyResponse, Route}

import java.net.{InetSocketAddress, SocketAddress}
import java.nio.file.{Path, Paths}
import java.util.UUID
import sttp.capabilities.fs2.Fs2Streams
import scala.concurrent.Future

case class NettyCatsServer[F[_]: Async](routes: Vector[Route[F]], options: NettyCatsServerOptions[F], config: NettyConfig) {
def addEndpoint(se: ServerEndpoint[Fs2Streams[F], F]): NettyCatsServer[F] = addEndpoints(List(se))
Expand Down Expand Up @@ -53,6 +55,9 @@ case class NettyCatsServer[F[_]: Async](routes: Vector[Route[F]], options: Netty
NettyCatsDomainSocketBinding(socket, stop)
}

private def unsafeRunAsync(block: () => F[ServerResponse[NettyResponse]]): (Future[ServerResponse[NettyResponse]], () => Future[Unit]) =
options.dispatcher.unsafeToFutureCancelable(block())

private def startUsingSocketOverride[SA <: SocketAddress](socketOverride: Option[SA]): F[(SA, () => F[Unit])] = {
val eventLoopGroup = config.eventLoopConfig.initEventLoopGroup()
implicit val monadError: MonadError[F] = new CatsMonadError[F]()
Expand All @@ -61,7 +66,7 @@ case class NettyCatsServer[F[_]: Async](routes: Vector[Route[F]], options: Netty
val channelFuture =
NettyBootstrap(
config,
new NettyServerHandler(route, (f: () => F[Unit]) => options.dispatcher.unsafeToFuture(f()), config.maxContentLength),
new NettyServerHandler(route, unsafeRunAsync, config.maxContentLength),
eventLoopGroup,
socketOverride
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,10 @@ class NettyCatsServerTest extends TestSuite with EitherValues {
multipart = false,
maxContentLength = Some(NettyCatsTestServerInterpreter.maxContentLength)
)
.tests() ++ new ServerStreamingTests(createServerTest, Fs2Streams[IO]).tests()
.tests() ++
new ServerStreamingTests(createServerTest, Fs2Streams[IO]).tests() ++
new ServerCancellationTests(createServerTest)(m, IO.asyncForIO).tests() ++
new NettyFs2StreamingCancellationTest(createServerTest).tests()

IO.pure((tests, eventLoopGroup))
} { case (_, eventLoopGroup) =>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package sttp.tapir.server.netty.cats

import cats.effect.IO
import cats.syntax.all._
import org.scalatest.matchers.should.Matchers._
import sttp.capabilities.fs2.Fs2Streams
import sttp.client3._
import sttp.monad.MonadError
import sttp.tapir.integ.cats.effect.CatsMonadError
import sttp.tapir.server.tests.{CreateServerTest, _}
import sttp.tapir.tests._
import sttp.tapir.{CodecFormat, _}

import java.nio.charset.StandardCharsets
import scala.concurrent.duration._
import cats.effect.std.Queue
import cats.effect.unsafe.implicits.global

class NettyFs2StreamingCancellationTest[OPTIONS, ROUTE](createServerTest: CreateServerTest[IO, Fs2Streams[IO], OPTIONS, ROUTE]) {
import createServerTest._

implicit val m: MonadError[IO] = new CatsMonadError[IO]()
def tests(): List[Test] = List({
val buffer = Queue.unbounded[IO, Byte].unsafeRunSync()
val body_20_slowly_emitted_bytes =
fs2.Stream.awakeEvery[IO](100.milliseconds).map(_ => 42.toByte).evalMap(b => { buffer.offer(b) >> IO.pure(b) }).take(100)
testServer(
endpoint.get
.in("streamCanceled")
.out(streamTextBody(Fs2Streams[IO])(CodecFormat.TextPlain(), Some(StandardCharsets.UTF_8))),
"Client cancelling streaming triggers cancellation on the server"
)(_ => pureResult(body_20_slowly_emitted_bytes.asRight[Unit])) { (backend, baseUri) =>

val expectedMaxAccumulated = 3

basicRequest
.get(uri"$baseUri/streamCanceled")
.send(backend)
.timeout(300.millis)
.attempt >>
IO.sleep(600.millis)
.flatMap(_ =>
buffer.size.flatMap(accumulated =>
IO(
assert(
accumulated <= expectedMaxAccumulated,
s"Buffer accumulated $accumulated elements. Expected < $expectedMaxAccumulated due to cancellation."
)
)
)
)
}
})
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import java.net.{InetSocketAddress, SocketAddress}
import java.nio.file.{Path, Paths}
import java.util.UUID
import scala.concurrent.{ExecutionContext, Future}
import sttp.tapir.server.model.ServerResponse

case class NettyFutureServer(routes: Vector[FutureRoute], options: NettyFutureServerOptions, config: NettyConfig)(implicit
ec: ExecutionContext
Expand Down Expand Up @@ -49,6 +50,12 @@ case class NettyFutureServer(routes: Vector[FutureRoute], options: NettyFutureSe
NettyFutureDomainSocketBinding(socket, stop)
}

private def unsafeRunAsync(
block: () => Future[ServerResponse[NettyResponse]]
): (Future[ServerResponse[NettyResponse]], () => Future[Unit]) = {
(block(), () => Future.unit) // noop cancellation handler, we can't cancel native Futures
}

private def startUsingSocketOverride[SA <: SocketAddress](socketOverride: Option[SA]): Future[(SA, () => Future[Unit])] = {
val eventLoopGroup = config.eventLoopConfig.initEventLoopGroup()
implicit val monadError: MonadError[Future] = new FutureMonad()
Expand All @@ -57,7 +64,7 @@ case class NettyFutureServer(routes: Vector[FutureRoute], options: NettyFutureSe
val channelFuture =
NettyBootstrap(
config,
new NettyServerHandler(route, (f: () => Future[Unit]) => f(), config.maxContentLength),
new NettyServerHandler(route, unsafeRunAsync, config.maxContentLength),
eventLoopGroup,
socketOverride
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,46 @@ import sttp.tapir.server.netty.NettyResponseContent.{
import sttp.tapir.server.netty.{NettyResponse, NettyServerRequest, Route}

import scala.collection.JavaConverters._
import scala.collection.mutable.{Queue => MutableQueue}
import scala.concurrent.Future
import scala.util.Failure
import scala.util.Success
import scala.util.control.NonFatal
import scala.concurrent.ExecutionContext

class NettyServerHandler[F[_]](route: Route[F], unsafeRunAsync: (() => F[Unit]) => Unit, maxContentLength: Option[Int])(implicit
/** @param unsafeRunAsync
* Function which dispatches given effect to run asynchronously, returning its result as a Future, and function of type `() =>
* Future[Unit]` allowing cancellation of that Future. For example, this can be realized by
* `cats.effect.std.Dispatcher.unsafeToFutureCancelable`.
*/
class NettyServerHandler[F[_]](
route: Route[F],
unsafeRunAsync: (() => F[ServerResponse[NettyResponse]]) => (Future[ServerResponse[NettyResponse]], () => Future[Unit]),
maxContentLength: Option[Int]
)(implicit
me: MonadError[F]
) extends SimpleChannelInboundHandler[HttpRequest] {

// Cancellation handling with eventLoopContext, lastResponseSent, and pendingResponses has been adapted
// from http4s: https://github.com/http4s/http4s-netty/pull/396/files
// By using the Netty event loop assigned to this channel we get two benefits:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's maybe mention here that this is copied from http4's code, just to maintain proper attribution :)

// 1. We can avoid the necessary hopping around of threads since Netty pipelines will
// only pass events up and down from within the event loop to which it is assigned.
// That means calls to ctx.read(), and ct.write(..), would have to be trampolined otherwise.
// 2. We get serialization of execution: the EventLoop is a serial execution queue so
// we can rest easy knowing that no two events will be executed in parallel.
private[this] var eventLoopContext: ExecutionContext = _

// This is used essentially as a queue, each incoming request attaches callbacks to this
// and replaces it to ensure that responses are written out in the same order that they came
// in.
private[this] var lastResponseSent: Future[Unit] = Future.unit

// We keep track of the cancellation tokens for all the requests in flight. This gives us
// observability into the number of requests in flight and the ability to cancel them all
// if the connection gets closed.
private[this] val pendingResponses = MutableQueue.empty[() => Future[Unit]]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

did you manage to drill down and understand why this is a queue - can you have multiple ongoing requests? is http 1 only?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I understand correctly this is HTTP2, so netty can ingest a request, dispatch async processing to another thread pool, and pick up next request. The responses will be returned in order though, even if request 2 is finished before request 1. https://medium.com/@akhaku/netty-data-model-threading-and-gotchas-cab820e4815a

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Edit: Turns out HTTP2 multiplexing is a different beast, which requires special setup of Netty, even an additional library (netty-codec-http2). It's powerful, because it allows opening a single connection and sending multiple requests and getting responses in any order, the protocol should take care of this. However, it's something different than what our server is capable of.

With this code, we are implementing support for cancellation in HTTP 1.1 pipelining: where a client can send multiple requests without waiting for the response to the first, and the server will process and respond to them in order. One of the main challenges is that responses must be returned in order, which can introduce head-of-line blocking if processing one request takes longer than others. HTTP/2 addresses these issues by introducing mentioned multiplexing.


private val logger = Logger[NettyServerHandler[F]]

private val EntityTooLarge: FullHttpResponse = {
Expand All @@ -40,33 +75,73 @@ class NettyServerHandler[F[_]](route: Route[F], unsafeRunAsync: (() => F[Unit])
res
}

override def handlerAdded(ctx: ChannelHandlerContext): Unit =
if (ctx.channel.isActive) {
initHandler(ctx)
}
override def channelActive(ctx: ChannelHandlerContext): Unit = initHandler(ctx)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is it so that either channelActive OR handlerAdded is called? won't both be called?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indeed sometimes both can be called, but now I've added a check in initHandler which will prevent double listener registration.


private[this] def initHandler(ctx: ChannelHandlerContext): Unit = {
if (eventLoopContext == null) {
// Initialize our ExecutionContext
eventLoopContext = ExecutionContext.fromExecutor(ctx.channel.eventLoop)

// When the channel closes we want to cancel any pending dispatches.
// Since the listener will be executed from the channels EventLoop everything is thread safe.
val _ = ctx.channel.closeFuture.addListener { (_: ChannelFuture) =>
logger.debug(s"Http channel to ${ctx.channel.remoteAddress} closed. Cancelling ${pendingResponses.length} responses.")
pendingResponses.foreach(_.apply())
}
}
}

override def channelRead0(ctx: ChannelHandlerContext, request: HttpRequest): Unit = {

def runRoute(req: HttpRequest) = {
def writeError500(req: HttpRequest, reason: Throwable): Unit = {
logger.error("Error while processing the request", reason)
// send 500
val res = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.INTERNAL_SERVER_ERROR)
res.headers().set(HttpHeaderNames.CONTENT_LENGTH, 0)
res.handleCloseAndKeepAliveHeaders(req)

route(NettyServerRequest(req))
.map {
case Some(response) => response
case None => ServerResponse.notFound
}
.flatMap((serverResponse: ServerResponse[NettyResponse]) =>
// in ZIO, exceptions thrown in .map become defects - instead, we want them represented as errors so that
// we get the 500 response, instead of dropping the request
try handleResponse(ctx, req, serverResponse).unit
catch {
case e: Exception => me.error[Unit](e)
}
)
.handleError { case ex: Exception =>
logger.error("Error while processing the request", ex)
// send 500
val res = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.INTERNAL_SERVER_ERROR)
res.headers().set(HttpHeaderNames.CONTENT_LENGTH, 0)
res.handleCloseAndKeepAliveHeaders(req)
ctx.writeAndFlush(res).closeIfNeeded(req)

ctx.writeAndFlush(res).closeIfNeeded(req)
me.unit(())
}
}

def runRoute(req: HttpRequest, releaseReq: () => Any = () => ()): Unit = {
val (runningFuture, cancellationSwitch) = unsafeRunAsync { () =>
route(NettyServerRequest(req))
.map {
case Some(response) => response
case None => ServerResponse.notFound
}
}
pendingResponses.enqueue(cancellationSwitch)
lastResponseSent = lastResponseSent.flatMap { _ =>
runningFuture.transform {
case Success(serverResponse) =>
pendingResponses.dequeue()
try {
handleResponse(ctx, req, serverResponse)
Success(())
} catch {
case NonFatal(ex) =>
writeError500(req, ex)
Failure(ex)
} finally {
val _ = releaseReq()
}
case Failure(NonFatal(ex)) =>
try {
writeError500(req, ex)
Failure(ex)
}
finally {
val _ = releaseReq()
}
case Failure(fatalException) => Failure(fatalException)
}(eventLoopContext)
}(eventLoopContext)
}

if (HttpUtil.is100ContinueExpected(request)) {
Expand All @@ -76,14 +151,9 @@ class NettyServerHandler[F[_]](route: Route[F], unsafeRunAsync: (() => F[Unit])
request match {
case full: FullHttpRequest =>
val req = full.retain()
unsafeRunAsync { () =>
runRoute(req)
.ensure(me.eval(req.release()))
} // exceptions should be handled
runRoute(req, () => req.release())
case req: StreamedHttpRequest =>
unsafeRunAsync { () =>
runRoute(req)
}
runRoute(req)
case _ => throw new UnsupportedOperationException(s"Unexpected Netty request type: ${request.getClass.getName}")
}

Expand Down Expand Up @@ -156,22 +226,22 @@ class NettyServerHandler[F[_]](route: Route[F], unsafeRunAsync: (() => F[Unit])

if (!HttpUtil.is100ContinueExpected(req) && !HttpUtil.isKeepAlive(req)) {
val future: ChannelFuture = ctx.writeAndFlush(EntityTooLargeClose.retainedDuplicate())
future.addListener(new ChannelFutureListener() {
val _ = future.addListener(new ChannelFutureListener() {
override def operationComplete(future: ChannelFuture) = {
if (!future.isSuccess()) {
logger.warn("Failed to send a 413 Request Entity Too Large.", future.cause())
}
ctx.close()
val _ = ctx.close()
}
})
} else {
ctx
val _ = ctx
.writeAndFlush(EntityTooLarge.retainedDuplicate())
.addListener(new ChannelFutureListener() {
override def operationComplete(future: ChannelFuture) = {
if (!future.isSuccess()) {
logger.warn("Failed to send a 413 Request Entity Too Large.", future.cause())
ctx.close()
val _ = ctx.close()
}
}
})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,18 @@ import io.netty.channel._
import io.netty.channel.unix.DomainSocketAddress
import sttp.capabilities.zio.ZioStreams
import sttp.tapir.server.ServerEndpoint
import sttp.tapir.server.netty.{NettyConfig, Route}
import sttp.tapir.server.model.ServerResponse
import sttp.tapir.server.netty.internal.{NettyBootstrap, NettyServerHandler}
import sttp.tapir.server.netty.zio.internal.ZioUtil.{nettyChannelFutureToScala, nettyFutureToScala}
import sttp.tapir.server.netty.{NettyConfig, NettyResponse, Route}
import sttp.tapir.ztapir.{RIOMonadError, ZServerEndpoint}
import zio.{RIO, Unsafe, ZIO}

import java.net.{InetSocketAddress, SocketAddress}
import java.nio.file.{Path, Paths}
import java.util.UUID
import scala.concurrent.ExecutionContext.Implicits
import scala.concurrent.Future

case class NettyZioServer[R](routes: Vector[RIO[R, Route[RIO[R, *]]]], options: NettyZioServerOptions[R], config: NettyConfig) {
def addEndpoint(se: ZServerEndpoint[R, ZioStreams]): NettyZioServer[R] = addEndpoints(List(se))
Expand Down Expand Up @@ -55,6 +58,17 @@ case class NettyZioServer[R](routes: Vector[RIO[R, Route[RIO[R, *]]]], options:
NettyZioDomainSocketBinding(socket, stop)
}

private def unsafeRunAsync(
runtime: zio.Runtime[R]
)(block: () => RIO[R, ServerResponse[NettyResponse]]): (Future[ServerResponse[NettyResponse]], () => Future[Unit]) = {
val cancelable = Unsafe.unsafe(implicit u =>
runtime.unsafe.runToFuture(
block()
)
)
(cancelable, () => cancelable.cancel().map(_ => ())(Implicits.global))
}

private def startUsingSocketOverride[SA <: SocketAddress](socketOverride: Option[SA]): RIO[R, (SA, () => RIO[R, Unit])] = for {
runtime <- ZIO.runtime[R]
routes <- ZIO.foreach(routes)(identity)
Expand All @@ -67,7 +81,7 @@ case class NettyZioServer[R](routes: Vector[RIO[R, Route[RIO[R, *]]]], options:
config,
new NettyServerHandler[RIO[R, *]](
route,
(f: () => RIO[R, Unit]) => Unsafe.unsafe(implicit u => runtime.unsafe.runToFuture(f())),
unsafeRunAsync(runtime),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe we can add a test for either ZIO/cats, that a long-running request is indeed cancelled (sth similar was present in the http4s PR)

config.maxContentLength
),
eventLoopGroup,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import sttp.tapir.server.tests._
import sttp.tapir.tests.{Test, TestSuite}
import sttp.tapir.ztapir.RIOMonadError
import zio.Task
import zio.interop.catz._

import scala.concurrent.Future

Expand All @@ -26,7 +27,8 @@ class NettyZioServerTest extends TestSuite with EitherValues {

val tests =
new AllServerTests(createServerTest, interpreter, backend, staticContent = false, multipart = false).tests() ++
new ServerStreamingTests(createServerTest, ZioStreams).tests()
new ServerStreamingTests(createServerTest, ZioStreams).tests() ++
new ServerCancellationTests(createServerTest)(monadError, asyncInstance).tests()

IO.pure((tests, eventLoopGroup))
} { case (_, eventLoopGroup) =>
Expand Down
Loading
Loading