-
Notifications
You must be signed in to change notification settings - Fork 422
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Netty server side cancellation #3256
Changes from all commits
b265784
a7a5669
d77a527
64ff166
c1ef9b4
239068e
ef26db4
8b047e5
120a2da
94d2b3f
c877bc6
68b8684
0c7956f
0bd6f49
ffa9b61
883698c
4790c80
995e6e0
bd8b94a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
package sttp.tapir.server.netty.cats | ||
|
||
import cats.effect.IO | ||
import cats.syntax.all._ | ||
import org.scalatest.matchers.should.Matchers._ | ||
import sttp.capabilities.fs2.Fs2Streams | ||
import sttp.client3._ | ||
import sttp.monad.MonadError | ||
import sttp.tapir.integ.cats.effect.CatsMonadError | ||
import sttp.tapir.server.tests.{CreateServerTest, _} | ||
import sttp.tapir.tests._ | ||
import sttp.tapir.{CodecFormat, _} | ||
|
||
import java.nio.charset.StandardCharsets | ||
import scala.concurrent.duration._ | ||
import cats.effect.std.Queue | ||
import cats.effect.unsafe.implicits.global | ||
|
||
/** Verifies server-side cancellation for the Netty cats backend: when the client abandons a
  * streaming response, the server's fs2 stream must be cancelled rather than drained to the end.
  */
class NettyFs2StreamingCancellationTest[OPTIONS, ROUTE](createServerTest: CreateServerTest[IO, Fs2Streams[IO], OPTIONS, ROUTE]) {
  import createServerTest._

  implicit val m: MonadError[IO] = new CatsMonadError[IO]()

  def tests(): List[Test] = List({
    // Every byte emitted by the server stream is also recorded here, so the queue size tells
    // us how many elements were actually produced before (expected) cancellation kicked in.
    val buffer = Queue.unbounded[IO, Byte].unsafeRunSync()
    // Emits one byte every 100ms, up to 100 bytes (~10s if never cancelled).
    val slowlyEmittedBytes =
      fs2.Stream.awakeEvery[IO](100.milliseconds).map(_ => 42.toByte).evalMap(b => buffer.offer(b) >> IO.pure(b)).take(100)
    testServer(
      endpoint.get
        .in("streamCanceled")
        .out(streamTextBody(Fs2Streams[IO])(CodecFormat.TextPlain(), Some(StandardCharsets.UTF_8))),
      "Client cancelling streaming triggers cancellation on the server"
    )(_ => pureResult(slowlyEmittedBytes.asRight[Unit])) { (backend, baseUri) =>
      // At one element per 100ms and a 300ms client timeout, a properly cancelled server
      // stream should have produced only a few elements; the bound leaves a small margin
      // for scheduling jitter.
      val expectedMaxAccumulated = 3

      basicRequest
        .get(uri"$baseUri/streamCanceled")
        .send(backend)
        .timeout(300.millis)
        .attempt >>
        // Give the server time to observe the disconnect before inspecting the queue.
        IO.sleep(600.millis)
          .flatMap(_ =>
            buffer.size.flatMap(accumulated =>
              IO(
                assert(
                  accumulated <= expectedMaxAccumulated,
                  s"Buffer accumulated $accumulated elements. Expected at most $expectedMaxAccumulated due to cancellation."
                )
              )
            )
          )
    }
  })
}
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -20,11 +20,46 @@ import sttp.tapir.server.netty.NettyResponseContent.{ | |
import sttp.tapir.server.netty.{NettyResponse, NettyServerRequest, Route} | ||
|
||
import scala.collection.JavaConverters._ | ||
import scala.collection.mutable.{Queue => MutableQueue} | ||
import scala.concurrent.Future | ||
import scala.util.Failure | ||
import scala.util.Success | ||
import scala.util.control.NonFatal | ||
import scala.concurrent.ExecutionContext | ||
|
||
class NettyServerHandler[F[_]](route: Route[F], unsafeRunAsync: (() => F[Unit]) => Unit, maxContentLength: Option[Int])(implicit | ||
/** @param unsafeRunAsync | ||
* Function which dispatches given effect to run asynchronously, returning its result as a Future, and function of type `() => | ||
* Future[Unit]` allowing cancellation of that Future. For example, this can be realized by | ||
* `cats.effect.std.Dispatcher.unsafeToFutureCancelable`. | ||
*/ | ||
class NettyServerHandler[F[_]]( | ||
route: Route[F], | ||
unsafeRunAsync: (() => F[ServerResponse[NettyResponse]]) => (Future[ServerResponse[NettyResponse]], () => Future[Unit]), | ||
maxContentLength: Option[Int] | ||
)(implicit | ||
me: MonadError[F] | ||
) extends SimpleChannelInboundHandler[HttpRequest] { | ||
|
||
// Cancellation handling with eventLoopContext, lastResponseSent, and pendingResponses has been adapted | ||
// from http4s: https://github.com/http4s/http4s-netty/pull/396/files | ||
// By using the Netty event loop assigned to this channel we get two benefits: | ||
// 1. We can avoid the necessary hopping around of threads since Netty pipelines will | ||
// only pass events up and down from within the event loop to which it is assigned. | ||
// That means calls to ctx.read(), and ctx.write(..), would have to be trampolined otherwise. | ||
// 2. We get serialization of execution: the EventLoop is a serial execution queue so | ||
// we can rest easy knowing that no two events will be executed in parallel. | ||
private[this] var eventLoopContext: ExecutionContext = _ | ||
|
||
// This is used essentially as a queue, each incoming request attaches callbacks to this | ||
// and replaces it to ensure that responses are written out in the same order that they came | ||
// in. | ||
private[this] var lastResponseSent: Future[Unit] = Future.unit | ||
|
||
// We keep track of the cancellation tokens for all the requests in flight. This gives us | ||
// observability into the number of requests in flight and the ability to cancel them all | ||
// if the connection gets closed. | ||
private[this] val pendingResponses = MutableQueue.empty[() => Future[Unit]] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. did you manage to drill down and understand why this is a queue - can you have multiple ongoing requests? is http 1 only? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If I understand correctly this is HTTP2, so netty can ingest a request, dispatch async processing to another thread pool, and pick up next request. The responses will be returned in order though, even if request 2 is finished before request 1. https://medium.com/@akhaku/netty-data-model-threading-and-gotchas-cab820e4815a There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Edit: Turns out HTTP2 multiplexing is a different beast, which requires special setup of Netty, even an additional library ( With this code, we are implementing support for cancellation in HTTP 1.1 pipelining: where a client can send multiple requests without waiting for the response to the first, and the server will process and respond to them in order. One of the main challenges is that responses must be returned in order, which can introduce head-of-line blocking if processing one request takes longer than others. HTTP/2 addresses these issues by introducing mentioned multiplexing. |
||
|
||
private val logger = Logger[NettyServerHandler[F]] | ||
|
||
private val EntityTooLarge: FullHttpResponse = { | ||
|
@@ -40,33 +75,73 @@ class NettyServerHandler[F[_]](route: Route[F], unsafeRunAsync: (() => F[Unit]) | |
res | ||
} | ||
|
||
// Netty may invoke `handlerAdded` (if the channel is already active when this handler is
// attached) and/or `channelActive`; per the reviewer discussion, both can fire for one
// channel. `initHandler` itself guards against double initialization by checking whether
// `eventLoopContext` has already been set, so calling it from both hooks is safe.
override def handlerAdded(ctx: ChannelHandlerContext): Unit = | ||
if (ctx.channel.isActive) { | ||
initHandler(ctx) | ||
} | ||
override def channelActive(ctx: ChannelHandlerContext): Unit = initHandler(ctx) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. is it so that either There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Indeed sometimes both can be called, but now I've added a check in initHandler which will prevent double listener registration. |
||
|
||
private[this] def initHandler(ctx: ChannelHandlerContext): Unit = { | ||
if (eventLoopContext == null) { | ||
// Initialize our ExecutionContext | ||
eventLoopContext = ExecutionContext.fromExecutor(ctx.channel.eventLoop) | ||
|
||
// When the channel closes we want to cancel any pending dispatches. | ||
// Since the listener will be executed from the channels EventLoop everything is thread safe. | ||
val _ = ctx.channel.closeFuture.addListener { (_: ChannelFuture) => | ||
logger.debug(s"Http channel to ${ctx.channel.remoteAddress} closed. Cancelling ${pendingResponses.length} responses.") | ||
pendingResponses.foreach(_.apply()) | ||
} | ||
} | ||
} | ||
|
||
override def channelRead0(ctx: ChannelHandlerContext, request: HttpRequest): Unit = { | ||
|
||
def runRoute(req: HttpRequest) = { | ||
def writeError500(req: HttpRequest, reason: Throwable): Unit = { | ||
logger.error("Error while processing the request", reason) | ||
// send 500 | ||
val res = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.INTERNAL_SERVER_ERROR) | ||
res.headers().set(HttpHeaderNames.CONTENT_LENGTH, 0) | ||
res.handleCloseAndKeepAliveHeaders(req) | ||
|
||
route(NettyServerRequest(req)) | ||
.map { | ||
case Some(response) => response | ||
case None => ServerResponse.notFound | ||
} | ||
.flatMap((serverResponse: ServerResponse[NettyResponse]) => | ||
// in ZIO, exceptions thrown in .map become defects - instead, we want them represented as errors so that | ||
// we get the 500 response, instead of dropping the request | ||
try handleResponse(ctx, req, serverResponse).unit | ||
catch { | ||
case e: Exception => me.error[Unit](e) | ||
} | ||
) | ||
.handleError { case ex: Exception => | ||
logger.error("Error while processing the request", ex) | ||
// send 500 | ||
val res = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.INTERNAL_SERVER_ERROR) | ||
res.headers().set(HttpHeaderNames.CONTENT_LENGTH, 0) | ||
res.handleCloseAndKeepAliveHeaders(req) | ||
ctx.writeAndFlush(res).closeIfNeeded(req) | ||
|
||
ctx.writeAndFlush(res).closeIfNeeded(req) | ||
me.unit(()) | ||
} | ||
} | ||
|
||
def runRoute(req: HttpRequest, releaseReq: () => Any = () => ()): Unit = { | ||
val (runningFuture, cancellationSwitch) = unsafeRunAsync { () => | ||
route(NettyServerRequest(req)) | ||
.map { | ||
case Some(response) => response | ||
case None => ServerResponse.notFound | ||
} | ||
} | ||
pendingResponses.enqueue(cancellationSwitch) | ||
lastResponseSent = lastResponseSent.flatMap { _ => | ||
runningFuture.transform { | ||
case Success(serverResponse) => | ||
pendingResponses.dequeue() | ||
try { | ||
handleResponse(ctx, req, serverResponse) | ||
Success(()) | ||
} catch { | ||
case NonFatal(ex) => | ||
writeError500(req, ex) | ||
Failure(ex) | ||
} finally { | ||
val _ = releaseReq() | ||
} | ||
case Failure(NonFatal(ex)) => | ||
try { | ||
writeError500(req, ex) | ||
Failure(ex) | ||
} | ||
finally { | ||
val _ = releaseReq() | ||
} | ||
case Failure(fatalException) => Failure(fatalException) | ||
}(eventLoopContext) | ||
}(eventLoopContext) | ||
} | ||
|
||
if (HttpUtil.is100ContinueExpected(request)) { | ||
|
@@ -76,14 +151,9 @@ class NettyServerHandler[F[_]](route: Route[F], unsafeRunAsync: (() => F[Unit]) | |
request match { | ||
case full: FullHttpRequest => | ||
val req = full.retain() | ||
unsafeRunAsync { () => | ||
runRoute(req) | ||
.ensure(me.eval(req.release())) | ||
} // exceptions should be handled | ||
runRoute(req, () => req.release()) | ||
case req: StreamedHttpRequest => | ||
unsafeRunAsync { () => | ||
runRoute(req) | ||
} | ||
runRoute(req) | ||
case _ => throw new UnsupportedOperationException(s"Unexpected Netty request type: ${request.getClass.getName}") | ||
} | ||
|
||
|
@@ -156,22 +226,22 @@ class NettyServerHandler[F[_]](route: Route[F], unsafeRunAsync: (() => F[Unit]) | |
|
||
if (!HttpUtil.is100ContinueExpected(req) && !HttpUtil.isKeepAlive(req)) { | ||
val future: ChannelFuture = ctx.writeAndFlush(EntityTooLargeClose.retainedDuplicate()) | ||
future.addListener(new ChannelFutureListener() { | ||
val _ = future.addListener(new ChannelFutureListener() { | ||
override def operationComplete(future: ChannelFuture) = { | ||
if (!future.isSuccess()) { | ||
logger.warn("Failed to send a 413 Request Entity Too Large.", future.cause()) | ||
} | ||
ctx.close() | ||
val _ = ctx.close() | ||
} | ||
}) | ||
} else { | ||
ctx | ||
val _ = ctx | ||
.writeAndFlush(EntityTooLarge.retainedDuplicate()) | ||
.addListener(new ChannelFutureListener() { | ||
override def operationComplete(future: ChannelFuture) = { | ||
if (!future.isSuccess()) { | ||
logger.warn("Failed to send a 413 Request Entity Too Large.", future.cause()) | ||
ctx.close() | ||
val _ = ctx.close() | ||
} | ||
} | ||
}) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,15 +4,18 @@ import io.netty.channel._ | |
import io.netty.channel.unix.DomainSocketAddress | ||
import sttp.capabilities.zio.ZioStreams | ||
import sttp.tapir.server.ServerEndpoint | ||
import sttp.tapir.server.netty.{NettyConfig, Route} | ||
import sttp.tapir.server.model.ServerResponse | ||
import sttp.tapir.server.netty.internal.{NettyBootstrap, NettyServerHandler} | ||
import sttp.tapir.server.netty.zio.internal.ZioUtil.{nettyChannelFutureToScala, nettyFutureToScala} | ||
import sttp.tapir.server.netty.{NettyConfig, NettyResponse, Route} | ||
import sttp.tapir.ztapir.{RIOMonadError, ZServerEndpoint} | ||
import zio.{RIO, Unsafe, ZIO} | ||
|
||
import java.net.{InetSocketAddress, SocketAddress} | ||
import java.nio.file.{Path, Paths} | ||
import java.util.UUID | ||
import scala.concurrent.ExecutionContext.Implicits | ||
import scala.concurrent.Future | ||
|
||
case class NettyZioServer[R](routes: Vector[RIO[R, Route[RIO[R, *]]]], options: NettyZioServerOptions[R], config: NettyConfig) { | ||
def addEndpoint(se: ZServerEndpoint[R, ZioStreams]): NettyZioServer[R] = addEndpoints(List(se)) | ||
|
@@ -55,6 +58,17 @@ case class NettyZioServer[R](routes: Vector[RIO[R, Route[RIO[R, *]]]], options: | |
NettyZioDomainSocketBinding(socket, stop) | ||
} | ||
|
||
/** Starts `block` eagerly on the given ZIO runtime, returning the running result as a
  * `Future` together with a thunk that interrupts the underlying fiber — the shape required
  * by `NettyServerHandler`'s `unsafeRunAsync` parameter.
  */
private def unsafeRunAsync( | ||
runtime: zio.Runtime[R] | ||
)(block: () => RIO[R, ServerResponse[NettyResponse]]): (Future[ServerResponse[NettyResponse]], () => Future[Unit]) = { | ||
val cancelable = Unsafe.unsafe(implicit u => | ||
runtime.unsafe.runToFuture( | ||
block() | ||
) | ||
) | ||
// `cancel()` interrupts the fiber; its result (the fiber's Exit) is discarded — callers
// only need to await completion of the cancellation. NOTE(review): `Implicits.global` is
// used just for this trivial `map`; a same-thread executor would avoid the extra hop.
(cancelable, () => cancelable.cancel().map(_ => ())(Implicits.global)) | ||
} | ||
|
||
private def startUsingSocketOverride[SA <: SocketAddress](socketOverride: Option[SA]): RIO[R, (SA, () => RIO[R, Unit])] = for { | ||
runtime <- ZIO.runtime[R] | ||
routes <- ZIO.foreach(routes)(identity) | ||
|
@@ -67,7 +81,7 @@ case class NettyZioServer[R](routes: Vector[RIO[R, Route[RIO[R, *]]]], options: | |
config, | ||
new NettyServerHandler[RIO[R, *]]( | ||
route, | ||
(f: () => RIO[R, Unit]) => Unsafe.unsafe(implicit u => runtime.unsafe.runToFuture(f())), | ||
unsafeRunAsync(runtime), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. maybe we can add a test for either ZIO/cats, that a long-running request is indeed cancelled (sth similar was present in the http4s PR) |
||
config.maxContentLength | ||
), | ||
eventLoopGroup, | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
let's maybe mention here that this is copied from http4's code, just to maintain proper attribution :)