Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for server-side request cancellation #396

Merged
merged 1 commit into from
Jan 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 64 additions & 57 deletions server/src/main/scala/org/http4s/netty/server/Http4sNettyHandler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,9 @@ import java.time.Instant
import java.time.ZoneId
import java.time.format.DateTimeFormatter
import java.util.Locale
import java.util.concurrent.atomic.AtomicLong
import scala.collection.mutable.{Queue => MutableQueue}
import scala.concurrent.ExecutionContext
import scala.concurrent.Future
import scala.concurrent.Promise
import scala.util.Failure
import scala.util.Success
import scala.util.control.NoStackTrace
Expand Down Expand Up @@ -77,25 +76,18 @@ private[netty] abstract class Http4sNettyHandler[F[_]](disp: Dispatcher[F])(impl
// That means calls to ctx.read(), and ct.write(..), would have to be trampolined otherwise.
// 2. We get serialization of execution: the EventLoop is a serial execution queue so
// we can rest easy knowing that no two events will be executed in parallel.
private[this] var eventLoopContext_ : ExecutionContext = _

// Note that this must be called from within the handlers EventLoop
private[this] def getEventLoopExecutionContext(ctx: ChannelHandlerContext): ExecutionContext = {
if (eventLoopContext_ == null) {
eventLoopContext_ = ExecutionContext.fromExecutor(ctx.channel.eventLoop)
}
eventLoopContext_
}

// We keep track of whether there are requests in flight. If there are, we don't respond to read
// complete, since back pressure is the responsibility of the streams.
private[this] val requestsInFlight = new AtomicLong()
private[this] var eventLoopContext: ExecutionContext = _

// This is used essentially as a queue, each incoming request attaches callbacks to this
// and replaces it to ensure that responses are written out in the same order that they came
// in.
private[this] var lastResponseSent: Future[Unit] = Future.unit

// We keep track of the cancellation tokens for all the requests in flight. This gives us
// observability into the number of requests in flight and the ability to cancel them all
// if the connection gets closed.
private[this] val pendingResponses = MutableQueue.empty[() => Future[Unit]]

// Compute the formatted date string only once per second, and cache the result.
// This should help microscopically under load.
private[this] var cachedDate: Long = Long.MinValue
Expand All @@ -113,7 +105,6 @@ private[netty] abstract class Http4sNettyHandler[F[_]](disp: Dispatcher[F])(impl

override def channelRead(ctx: ChannelHandlerContext, msg: Object): Unit = {
logger.trace(s"channelRead: ctx = $ctx, msg = $msg")
val eventLoopContext = getEventLoopExecutionContext(ctx)
val newTick = System.currentTimeMillis() / 1000
if (cachedDate < newTick) {
cachedDateString = RFC7231InstantFormatter.format(Instant.ofEpochSecond(newTick))
Expand All @@ -122,45 +113,41 @@ private[netty] abstract class Http4sNettyHandler[F[_]](disp: Dispatcher[F])(impl

msg match {
case req: HttpRequest =>
requestsInFlight.incrementAndGet()
val p: Promise[(HttpResponse, F[Unit])] =
Promise[(HttpResponse, F[Unit])]()

val reqAndCleanup = handle(ctx.channel(), req, cachedDateString).allocated
// Start execution of the handler.
disp.unsafeRunAndForget(reqAndCleanup.attempt.flatMap {
case Right(result) => F.delay(p.success(result))
case Left(err) => F.delay(p.failure(err))
})

val (f, cancelRequest) = disp.unsafeToFutureCancelable(reqAndCleanup)
pendingResponses.enqueue(cancelRequest)

// This attaches all writes sequentially using
// LastResponseSent as a queue. `eventLoopContext` ensures we do not
// CTX switch the writes.
lastResponseSent = lastResponseSent.flatMap[Unit] { _ =>
p.future
.transform {
case Success((response, cleanup)) =>
if (requestsInFlight.decrementAndGet() == 0)
// Since we've now gone down to zero, we need to issue a
// read, in case we ignored an earlier read complete
ctx.read()
void {
ctx
.writeAndFlush(response)
.addListener((_: ChannelFuture) => disp.unsafeRunAndForget(cleanup))
}
Success(())

case Failure(NonFatal(e)) =>
logger.warn(e)(
"Error caught during service handling. Check the configured ServiceErrorHandler.")
void {
sendSimpleErrorResponse(ctx, HttpResponseStatus.INTERNAL_SERVER_ERROR)
}
Failure(e)

case Failure(e) => // fatal: just let it go.
Failure(e)
}(eventLoopContext)
f.transform {
case Success((response, cleanup)) =>
pendingResponses.dequeue()
if (pendingResponses.isEmpty)
// Since we've now gone down to zero, we need to issue a
// read, in case we ignored an earlier read complete
ctx.read()
void {
ctx
.writeAndFlush(response)
.addListener((_: ChannelFuture) => disp.unsafeRunAndForget(cleanup))
}
Success(())

case Failure(NonFatal(e)) =>
logger.warn(e)(
"Error caught during service handling. Check the configured ServiceErrorHandler.")
void {
sendSimpleErrorResponse(ctx, HttpResponseStatus.INTERNAL_SERVER_ERROR)
}
Failure(e)

case Failure(e) => // fatal: just let it go.
Failure(e)
}(eventLoopContext)
}(eventLoopContext)

case LastHttpContent.EMPTY_LAST_CONTENT =>
Expand All @@ -182,7 +169,7 @@ private[netty] abstract class Http4sNettyHandler[F[_]](disp: Dispatcher[F])(impl
// we don't get in the way of the request body reactive streams,
// which will be using channel read complete and read to implement
// their own back pressure
if (requestsInFlight.get() == 0) {
if (pendingResponses.isEmpty) {
ctx.read()
} else {
// otherwise forward it, so that any handler publishers downstream
Expand Down Expand Up @@ -210,12 +197,12 @@ private[netty] abstract class Http4sNettyHandler[F[_]](disp: Dispatcher[F])(impl
}
}

override def channelActive(ctx: ChannelHandlerContext): Unit = void {
// AUTO_READ is off, so need to do the first read explicitly.
// this method is called when the channel is registered with the event loop,
// so ctx.read is automatically safe here w/o needing an isRegistered().
ctx.read()
}
override def handlerAdded(ctx: ChannelHandlerContext): Unit =
if (ctx.channel.isActive) {
initHandler(ctx)
}

override def channelActive(ctx: ChannelHandlerContext): Unit = initHandler(ctx)

override def userEventTriggered(ctx: ChannelHandlerContext, evt: scala.Any): Unit = void {
evt match {
Expand All @@ -226,7 +213,27 @@ private[netty] abstract class Http4sNettyHandler[F[_]](disp: Dispatcher[F])(impl
}
}

private def sendSimpleErrorResponse(
private[this] def initHandler(ctx: ChannelHandlerContext): Unit =
// Guard against double initialization. It shouldn't matter, but might as well be safe.
if (eventLoopContext == null) void {
// Initialize our ExecutionContext
eventLoopContext = ExecutionContext.fromExecutor(ctx.channel.eventLoop)

// When the channel closes we want to cancel any pending dispatches.
// Since the listener will be executed from the channels EventLoop everything is thread safe.
ctx.channel.closeFuture.addListener { (_: ChannelFuture) =>
logger.debug(
s"Http channel to ${ctx.channel.remoteAddress} closed. Cancelling ${pendingResponses.length} responses.")
pendingResponses.foreach(_.apply())
}

// AUTO_READ is off, so need to do the first read explicitly.
// this method is called when the channel is registered with the event loop,
// so ctx.read is automatically safe here w/o needing an isRegistered().
ctx.read()
}

private[this] def sendSimpleErrorResponse(
ctx: ChannelHandlerContext,
status: HttpResponseStatus): ChannelFuture = {
val response = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, status)
Expand Down
30 changes: 30 additions & 0 deletions server/src/test/scala/org/http4s/netty/server/ServerTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package org.http4s.netty.server

import cats.data.Kleisli
import cats.effect.Deferred
import cats.effect.IO
import cats.effect.Resource
import cats.implicits._
Expand Down Expand Up @@ -117,6 +118,35 @@ abstract class ServerTest extends IOSuite {
}
}
}

test("requests can be cancelled") {
val ref: Deferred[IO, Boolean] = Deferred.unsafe[IO, Boolean]
val route = HttpRoutes
.of[IO] { case GET -> Root / "cancel" =>
(IO.never *> IO.defer(Ok(""))).onCancel(ref.complete(true).void)
}
.orNotFound

val server: Resource[IO, Server] = NettyServerBuilder[IO]
.withHttpApp(route)
.withEventLoopThreads(1)
.withIdleTimeout(
1.seconds
) // Basically going to send the request and hope it times out immediately.
.withoutBanner
.bindAny()
.resource

server.use { server =>
val uri = server.baseUri / "cancel"
val resp = client().statusFromUri(uri).timeout(15.seconds).attempt.map { case other =>
fail(s"unexpectedly received a result: $other")
}

IO.race(resp, ref.get.map(assert(_)))
.map(_.merge)
}
}
}

class JDKServerTest extends ServerTest {
Expand Down