diff --git a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServer.scala b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServer.scala index 604a13e57d..7e7b3c8fea 100644 --- a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServer.scala +++ b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServer.scala @@ -5,17 +5,19 @@ import cats.effect.{Async, IO, Resource} import cats.syntax.all._ import io.netty.channel._ import io.netty.channel.unix.DomainSocketAddress +import sttp.capabilities.fs2.Fs2Streams import sttp.monad.MonadError import sttp.tapir.integ.cats.effect.CatsMonadError import sttp.tapir.server.ServerEndpoint +import sttp.tapir.server.model.ServerResponse import sttp.tapir.server.netty.cats.internal.CatsUtil.{nettyChannelFutureToScala, nettyFutureToScala} import sttp.tapir.server.netty.internal.{NettyBootstrap, NettyServerHandler} -import sttp.tapir.server.netty.{NettyConfig, Route} +import sttp.tapir.server.netty.{NettyConfig, NettyResponse, Route} import java.net.{InetSocketAddress, SocketAddress} import java.nio.file.{Path, Paths} import java.util.UUID -import sttp.capabilities.fs2.Fs2Streams +import scala.concurrent.Future case class NettyCatsServer[F[_]: Async](routes: Vector[Route[F]], options: NettyCatsServerOptions[F], config: NettyConfig) { def addEndpoint(se: ServerEndpoint[Fs2Streams[F], F]): NettyCatsServer[F] = addEndpoints(List(se)) @@ -53,6 +55,9 @@ case class NettyCatsServer[F[_]: Async](routes: Vector[Route[F]], options: Netty NettyCatsDomainSocketBinding(socket, stop) } + private def unsafeRunAsync(block: () => F[ServerResponse[NettyResponse]]): (Future[ServerResponse[NettyResponse]], () => Future[Unit]) = + options.dispatcher.unsafeToFutureCancelable(block()) + private def startUsingSocketOverride[SA <: SocketAddress](socketOverride: Option[SA]): F[(SA, () => F[Unit])] = { val eventLoopGroup = config.eventLoopConfig.initEventLoopGroup() implicit val monadError: MonadError[F] = new CatsMonadError[F]() @@ -61,7 +66,7 @@ case class NettyCatsServer[F[_]: Async](routes: Vector[Route[F]], options: Netty val channelFuture = NettyBootstrap( config, - new NettyServerHandler(route, (f: () => F[Unit]) => options.dispatcher.unsafeToFuture(f()), config.maxContentLength), + new NettyServerHandler(route, unsafeRunAsync, config.maxContentLength), eventLoopGroup, socketOverride ) diff --git a/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsServerTest.scala b/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsServerTest.scala index 84e90e1266..7b3f2a1303 100644 --- a/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsServerTest.scala +++ b/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsServerTest.scala @@ -31,7 +31,10 @@ class NettyCatsServerTest extends TestSuite with EitherValues { multipart = false, maxContentLength = Some(NettyCatsTestServerInterpreter.maxContentLength) ) - .tests() ++ new ServerStreamingTests(createServerTest, Fs2Streams[IO]).tests() + .tests() ++ + new ServerStreamingTests(createServerTest, Fs2Streams[IO]).tests() ++ + new ServerCancellationTests(createServerTest)(m, IO.asyncForIO).tests() ++ + new NettyFs2StreamingCancellationTest(createServerTest).tests() IO.pure((tests, eventLoopGroup)) } { case (_, eventLoopGroup) => diff --git a/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyFs2StreamingCancellationTest.scala b/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyFs2StreamingCancellationTest.scala new file mode 100644 index 0000000000..df25c4fefc --- /dev/null +++ b/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyFs2StreamingCancellationTest.scala @@ -0,0 +1,54 @@ +package sttp.tapir.server.netty.cats + +import cats.effect.IO +import cats.syntax.all._ +import org.scalatest.matchers.should.Matchers._ +import sttp.capabilities.fs2.Fs2Streams +import sttp.client3._ +import sttp.monad.MonadError +import sttp.tapir.integ.cats.effect.CatsMonadError +import sttp.tapir.server.tests.{CreateServerTest, _} +import sttp.tapir.tests._ +import sttp.tapir.{CodecFormat, _} + +import java.nio.charset.StandardCharsets +import scala.concurrent.duration._ +import cats.effect.std.Queue +import cats.effect.unsafe.implicits.global + +class NettyFs2StreamingCancellationTest[OPTIONS, ROUTE](createServerTest: CreateServerTest[IO, Fs2Streams[IO], OPTIONS, ROUTE]) { + import createServerTest._ + + implicit val m: MonadError[IO] = new CatsMonadError[IO]() + def tests(): List[Test] = List({ + val buffer = Queue.unbounded[IO, Byte].unsafeRunSync() + val body_20_slowly_emitted_bytes = + fs2.Stream.awakeEvery[IO](100.milliseconds).map(_ => 42.toByte).evalMap(b => { buffer.offer(b) >> IO.pure(b) }).take(100) + testServer( + endpoint.get + .in("streamCanceled") + .out(streamTextBody(Fs2Streams[IO])(CodecFormat.TextPlain(), Some(StandardCharsets.UTF_8))), + "Client cancelling streaming triggers cancellation on the server" + )(_ => pureResult(body_20_slowly_emitted_bytes.asRight[Unit])) { (backend, baseUri) => + + val expectedMaxAccumulated = 3 + + basicRequest + .get(uri"$baseUri/streamCanceled") + .send(backend) + .timeout(300.millis) + .attempt >> + IO.sleep(600.millis) + .flatMap(_ => + buffer.size.flatMap(accumulated => + IO( + assert( + accumulated <= expectedMaxAccumulated, + s"Buffer accumulated $accumulated elements. Expected < $expectedMaxAccumulated due to cancellation." + ) + ) + ) + ) + } + }) +} diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyFutureServer.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyFutureServer.scala index b53483a0a9..d92c704b02 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyFutureServer.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyFutureServer.scala @@ -11,6 +11,7 @@ import java.net.{InetSocketAddress, SocketAddress} import java.nio.file.{Path, Paths} import java.util.UUID import scala.concurrent.{ExecutionContext, Future} +import sttp.tapir.server.model.ServerResponse case class NettyFutureServer(routes: Vector[FutureRoute], options: NettyFutureServerOptions, config: NettyConfig)(implicit ec: ExecutionContext @@ -49,6 +50,12 @@ case class NettyFutureServer(routes: Vector[FutureRoute], options: NettyFutureSe NettyFutureDomainSocketBinding(socket, stop) } + private def unsafeRunAsync( + block: () => Future[ServerResponse[NettyResponse]] + ): (Future[ServerResponse[NettyResponse]], () => Future[Unit]) = { + (block(), () => Future.unit) // noop cancellation handler, we can't cancel native Futures + } + private def startUsingSocketOverride[SA <: SocketAddress](socketOverride: Option[SA]): Future[(SA, () => Future[Unit])] = { val eventLoopGroup = config.eventLoopConfig.initEventLoopGroup() implicit val monadError: MonadError[Future] = new FutureMonad() @@ -57,7 +64,7 @@ case class NettyFutureServer(routes: Vector[FutureRoute], options: NettyFutureSe val channelFuture = NettyBootstrap( config, - new NettyServerHandler(route, (f: () => Future[Unit]) => f(), config.maxContentLength), + new NettyServerHandler(route, unsafeRunAsync, config.maxContentLength), eventLoopGroup, socketOverride ) diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala index 6a763453d2..d5bf90f4c7 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala @@ -20,11 +20,46 @@ import sttp.tapir.server.netty.NettyResponseContent.{ import sttp.tapir.server.netty.{NettyResponse, NettyServerRequest, Route} import scala.collection.JavaConverters._ +import scala.collection.mutable.{Queue => MutableQueue} +import scala.concurrent.Future +import scala.util.Failure +import scala.util.Success +import scala.util.control.NonFatal +import scala.concurrent.ExecutionContext -class NettyServerHandler[F[_]](route: Route[F], unsafeRunAsync: (() => F[Unit]) => Unit, maxContentLength: Option[Int])(implicit +/** @param unsafeRunAsync + * Function which dispatches given effect to run asynchronously, returning its result as a Future, and function of type `() => + * Future[Unit]` allowing cancellation of that Future. For example, this can be realized by + * `cats.effect.std.Dispatcher.unsafeToFutureCancelable`. + */ +class NettyServerHandler[F[_]]( + route: Route[F], + unsafeRunAsync: (() => F[ServerResponse[NettyResponse]]) => (Future[ServerResponse[NettyResponse]], () => Future[Unit]), + maxContentLength: Option[Int] +)(implicit me: MonadError[F] ) extends SimpleChannelInboundHandler[HttpRequest] { + // Cancellation handling with eventLoopContext, lastResponseSent, and pendingResponses has been adapted + // from http4s: https://github.com/http4s/http4s-netty/pull/396/files + // By using the Netty event loop assigned to this channel we get two benefits: + // 1. We can avoid the necessary hopping around of threads since Netty pipelines will + // only pass events up and down from within the event loop to which it is assigned. + // That means calls to ctx.read(), and ct.write(..), would have to be trampolined otherwise. + // 2. We get serialization of execution: the EventLoop is a serial execution queue so + // we can rest easy knowing that no two events will be executed in parallel. + private[this] var eventLoopContext: ExecutionContext = _ + + // This is used essentially as a queue, each incoming request attaches callbacks to this + // and replaces it to ensure that responses are written out in the same order that they came + // in. + private[this] var lastResponseSent: Future[Unit] = Future.unit + + // We keep track of the cancellation tokens for all the requests in flight. This gives us + // observability into the number of requests in flight and the ability to cancel them all + // if the connection gets closed. + private[this] val pendingResponses = MutableQueue.empty[() => Future[Unit]] + private val logger = Logger[NettyServerHandler[F]] private val EntityTooLarge: FullHttpResponse = { @@ -40,33 +75,73 @@ class NettyServerHandler[F[_]](route: Route[F], unsafeRunAsync: (() => F[Unit]) res } + override def handlerAdded(ctx: ChannelHandlerContext): Unit = + if (ctx.channel.isActive) { + initHandler(ctx) + } + override def channelActive(ctx: ChannelHandlerContext): Unit = initHandler(ctx) + + private[this] def initHandler(ctx: ChannelHandlerContext): Unit = { + if (eventLoopContext == null) { + // Initialize our ExecutionContext + eventLoopContext = ExecutionContext.fromExecutor(ctx.channel.eventLoop) + + // When the channel closes we want to cancel any pending dispatches. + // Since the listener will be executed from the channels EventLoop everything is thread safe. + val _ = ctx.channel.closeFuture.addListener { (_: ChannelFuture) => + logger.debug(s"Http channel to ${ctx.channel.remoteAddress} closed. Cancelling ${pendingResponses.length} responses.") + pendingResponses.foreach(_.apply()) + } + } + } + override def channelRead0(ctx: ChannelHandlerContext, request: HttpRequest): Unit = { - def runRoute(req: HttpRequest) = { + def writeError500(req: HttpRequest, reason: Throwable): Unit = { + logger.error("Error while processing the request", reason) + // send 500 + val res = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.INTERNAL_SERVER_ERROR) + res.headers().set(HttpHeaderNames.CONTENT_LENGTH, 0) + res.handleCloseAndKeepAliveHeaders(req) - route(NettyServerRequest(req)) - .map { - case Some(response) => response - case None => ServerResponse.notFound - } - .flatMap((serverResponse: ServerResponse[NettyResponse]) => - // in ZIO, exceptions thrown in .map become defects - instead, we want them represented as errors so that - // we get the 500 response, instead of dropping the request - try handleResponse(ctx, req, serverResponse).unit - catch { - case e: Exception => me.error[Unit](e) - } - ) - .handleError { case ex: Exception => - logger.error("Error while processing the request", ex) - // send 500 - val res = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.INTERNAL_SERVER_ERROR) - res.headers().set(HttpHeaderNames.CONTENT_LENGTH, 0) - res.handleCloseAndKeepAliveHeaders(req) + ctx.writeAndFlush(res).closeIfNeeded(req) - ctx.writeAndFlush(res).closeIfNeeded(req) - me.unit(()) - } + } + + def runRoute(req: HttpRequest, releaseReq: () => Any = () => ()): Unit = { + val (runningFuture, cancellationSwitch) = unsafeRunAsync { () => + route(NettyServerRequest(req)) + .map { + case Some(response) => response + case None => ServerResponse.notFound + } + } + pendingResponses.enqueue(cancellationSwitch) + lastResponseSent = lastResponseSent.flatMap { _ => + runningFuture.transform { + case Success(serverResponse) => + pendingResponses.dequeue() + try { + handleResponse(ctx, req, serverResponse) + Success(()) + } catch { + case NonFatal(ex) => + writeError500(req, ex) + Failure(ex) + } finally { + val _ = releaseReq() + } + case Failure(NonFatal(ex)) => + try { + writeError500(req, ex) + Failure(ex) + } + finally { + val _ = releaseReq() + } + case Failure(fatalException) => Failure(fatalException) + }(eventLoopContext) + }(eventLoopContext) } if (HttpUtil.is100ContinueExpected(request)) { @@ -76,14 +151,9 @@ class NettyServerHandler[F[_]](route: Route[F], unsafeRunAsync: (() => F[Unit]) request match { case full: FullHttpRequest => val req = full.retain() - unsafeRunAsync { () => - runRoute(req) - .ensure(me.eval(req.release())) - } // exceptions should be handled + runRoute(req, () => req.release()) case req: StreamedHttpRequest => - unsafeRunAsync { () => - runRoute(req) - } + runRoute(req) case _ => throw new UnsupportedOperationException(s"Unexpected Netty request type: ${request.getClass.getName}") } @@ -156,22 +226,22 @@ class NettyServerHandler[F[_]](route: Route[F], unsafeRunAsync: (() => F[Unit]) if (!HttpUtil.is100ContinueExpected(req) && !HttpUtil.isKeepAlive(req)) { val future: ChannelFuture = ctx.writeAndFlush(EntityTooLargeClose.retainedDuplicate()) - future.addListener(new ChannelFutureListener() { + val _ = future.addListener(new ChannelFutureListener() { override def operationComplete(future: ChannelFuture) = { if (!future.isSuccess()) { logger.warn("Failed to send a 413 Request Entity Too Large.", future.cause()) } - ctx.close() + val _ = ctx.close() } }) } else { - ctx + val _ = ctx .writeAndFlush(EntityTooLarge.retainedDuplicate()) .addListener(new ChannelFutureListener() { override def operationComplete(future: ChannelFuture) = { if (!future.isSuccess()) { logger.warn("Failed to send a 413 Request Entity Too Large.", future.cause()) - ctx.close() + val _ = ctx.close() } } }) diff --git a/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/NettyZioServer.scala b/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/NettyZioServer.scala index 91e2499164..c642949bc6 100644 --- a/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/NettyZioServer.scala +++ b/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/NettyZioServer.scala @@ -4,15 +4,18 @@ import io.netty.channel._ import io.netty.channel.unix.DomainSocketAddress import sttp.capabilities.zio.ZioStreams import sttp.tapir.server.ServerEndpoint -import sttp.tapir.server.netty.{NettyConfig, Route} +import sttp.tapir.server.model.ServerResponse import sttp.tapir.server.netty.internal.{NettyBootstrap, NettyServerHandler} import sttp.tapir.server.netty.zio.internal.ZioUtil.{nettyChannelFutureToScala, nettyFutureToScala} +import sttp.tapir.server.netty.{NettyConfig, NettyResponse, Route} import sttp.tapir.ztapir.{RIOMonadError, ZServerEndpoint} import zio.{RIO, Unsafe, ZIO} import java.net.{InetSocketAddress, SocketAddress} import java.nio.file.{Path, Paths} import java.util.UUID +import scala.concurrent.ExecutionContext.Implicits +import scala.concurrent.Future case class NettyZioServer[R](routes: Vector[RIO[R, Route[RIO[R, *]]]], options: NettyZioServerOptions[R], config: NettyConfig) { def addEndpoint(se: ZServerEndpoint[R, ZioStreams]): NettyZioServer[R] = addEndpoints(List(se)) @@ -55,6 +58,17 @@ case class NettyZioServer[R](routes: Vector[RIO[R, Route[RIO[R, *]]]], options: NettyZioDomainSocketBinding(socket, stop) } + private def unsafeRunAsync( + runtime: zio.Runtime[R] + )(block: () => RIO[R, ServerResponse[NettyResponse]]): (Future[ServerResponse[NettyResponse]], () => Future[Unit]) = { + val cancelable = Unsafe.unsafe(implicit u => + runtime.unsafe.runToFuture( + block() + ) + ) + (cancelable, () => cancelable.cancel().map(_ => ())(Implicits.global)) + } + private def startUsingSocketOverride[SA <: SocketAddress](socketOverride: Option[SA]): RIO[R, (SA, () => RIO[R, Unit])] = for { runtime <- ZIO.runtime[R] routes <- ZIO.foreach(routes)(identity) @@ -67,7 +81,7 @@ case class NettyZioServer[R](routes: Vector[RIO[R, Route[RIO[R, *]]]], options: config, new NettyServerHandler[RIO[R, *]]( route, - (f: () => RIO[R, Unit]) => Unsafe.unsafe(implicit u => runtime.unsafe.runToFuture(f())), + unsafeRunAsync(runtime), config.maxContentLength ), eventLoopGroup, diff --git a/server/netty-server/zio/src/test/scala/sttp/tapir/server/netty/zio/NettyZioServerTest.scala b/server/netty-server/zio/src/test/scala/sttp/tapir/server/netty/zio/NettyZioServerTest.scala index eeb8766e31..227d9223db 100644 --- a/server/netty-server/zio/src/test/scala/sttp/tapir/server/netty/zio/NettyZioServerTest.scala +++ b/server/netty-server/zio/src/test/scala/sttp/tapir/server/netty/zio/NettyZioServerTest.scala @@ -10,6 +10,7 @@ import sttp.tapir.server.tests._ import sttp.tapir.tests.{Test, TestSuite} import sttp.tapir.ztapir.RIOMonadError import zio.Task +import zio.interop.catz._ import scala.concurrent.Future @@ -26,7 +27,8 @@ class NettyZioServerTest extends TestSuite with EitherValues { val tests = new AllServerTests(createServerTest, interpreter, backend, staticContent = false, multipart = false).tests() ++ - new ServerStreamingTests(createServerTest, ZioStreams).tests() + new ServerStreamingTests(createServerTest, ZioStreams).tests() ++ + new ServerCancellationTests(createServerTest)(monadError, asyncInstance).tests() IO.pure((tests, eventLoopGroup)) } { case (_, eventLoopGroup) => diff --git a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerCancellationTests.scala b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerCancellationTests.scala new file mode 100644 index 0000000000..3d3435a343 --- /dev/null +++ b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerCancellationTests.scala @@ -0,0 +1,55 @@ +package sttp.tapir.server.tests + +import cats.effect.IO +import cats.effect.kernel.Async +import cats.effect.syntax.all._ +import cats.syntax.all._ +import org.scalatest.matchers.should.Matchers._ +import sttp.client3._ +import sttp.monad.MonadError +import sttp.tapir._ +import sttp.tapir.tests._ + +import java.util.concurrent.atomic.AtomicBoolean +import java.util.concurrent.{Semaphore, TimeUnit} +import scala.concurrent.duration._ + +class ServerCancellationTests[F[_], OPTIONS, ROUTE](createServerTest: CreateServerTest[F, Any, OPTIONS, ROUTE])(implicit + m: MonadError[F], + async: Async[F] +) { + import createServerTest._ + + def tests(): List[Test] = List({ + val canceledSemaphore = new Semaphore(1) + val canceled: AtomicBoolean = new AtomicBoolean(false) + testServerLogic( + endpoint + .out(plainBody[String]) + .serverLogic { _ => + (m.eval(canceledSemaphore.acquire())) >> (async.sleep(15.seconds) >> pureResult("processing finished".asRight[Unit])) + .onCancel(m.eval(canceled.set(true)) >> m.eval(canceledSemaphore.release())) + }, + "Client cancelling request triggers cancellation on the server" + ) { (backend, baseUri) => + val resp: IO[_] = basicRequest.get(uri"$baseUri").readTimeout(300.millis).send(backend) + + resp + .map { case result => + fail(s"Expected cancellation, but received a result: $result") + } + .handleErrorWith { + case _: SttpClientException.TimeoutException => // expected, this is how we trigged client-side cancellation + IO( + assert( + canceledSemaphore.tryAcquire(30L, TimeUnit.SECONDS), + "Timeout when waiting for cancellation to be handled as expected" + ) + ) >> + IO(assert(canceled.get(), "Cancellation expected, but not registered!")) + case other => + IO(fail(s"TimeoutException expected, but got $other")) + } + } + }) +}