From b26578417a8a6c6c1f6c7d4244ccbe6805a82f0a Mon Sep 17 00:00:00 2001 From: kciesielski Date: Wed, 18 Oct 2023 17:17:31 +0200 Subject: [PATCH 01/19] Support for server-side request cancellation (CE3, Future) --- .../server/netty/cats/NettyCatsServer.scala | 6 +- .../server/netty/NettyFutureServer.scala | 6 +- .../netty/internal/NettyServerHandler.scala | 59 +++++++++++++++++-- 3 files changed, 65 insertions(+), 6 deletions(-) diff --git a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServer.scala b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServer.scala index 604a13e57d..e6fe404d15 100644 --- a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServer.scala +++ b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServer.scala @@ -16,6 +16,7 @@ import java.net.{InetSocketAddress, SocketAddress} import java.nio.file.{Path, Paths} import java.util.UUID import sttp.capabilities.fs2.Fs2Streams +import scala.concurrent.Future case class NettyCatsServer[F[_]: Async](routes: Vector[Route[F]], options: NettyCatsServerOptions[F], config: NettyConfig) { def addEndpoint(se: ServerEndpoint[Fs2Streams[F], F]): NettyCatsServer[F] = addEndpoints(List(se)) @@ -53,6 +54,9 @@ case class NettyCatsServer[F[_]: Async](routes: Vector[Route[F]], options: Netty NettyCatsDomainSocketBinding(socket, stop) } + private def unsafeRunAsync(block: () => F[Unit]): () => Future[Unit] = + options.dispatcher.unsafeRunCancelable(block()) + private def startUsingSocketOverride[SA <: SocketAddress](socketOverride: Option[SA]): F[(SA, () => F[Unit])] = { val eventLoopGroup = config.eventLoopConfig.initEventLoopGroup() implicit val monadError: MonadError[F] = new CatsMonadError[F]() @@ -61,7 +65,7 @@ case class NettyCatsServer[F[_]: Async](routes: Vector[Route[F]], options: Netty val channelFuture = NettyBootstrap( config, - new NettyServerHandler(route, (f: () => F[Unit]) => options.dispatcher.unsafeToFuture(f()), config.maxContentLength), + new NettyServerHandler(route, unsafeRunAsync, config.maxContentLength), eventLoopGroup, socketOverride ) diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyFutureServer.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyFutureServer.scala index b53483a0a9..a343758a98 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyFutureServer.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyFutureServer.scala @@ -49,6 +49,10 @@ case class NettyFutureServer(routes: Vector[FutureRoute], options: NettyFutureSe NettyFutureDomainSocketBinding(socket, stop) } + private def unsafeRunAsync(block: () => Future[Unit]): () => Future[Unit] = + block() + () => Future.unit // noop cancellation handler, we can't cancel native Futures + private def startUsingSocketOverride[SA <: SocketAddress](socketOverride: Option[SA]): Future[(SA, () => Future[Unit])] = { val eventLoopGroup = config.eventLoopConfig.initEventLoopGroup() implicit val monadError: MonadError[Future] = new FutureMonad() @@ -57,7 +61,7 @@ case class NettyFutureServer(routes: Vector[FutureRoute], options: NettyFutureSe val channelFuture = NettyBootstrap( config, - new NettyServerHandler(route, (f: () => Future[Unit]) => f(), config.maxContentLength), + new NettyServerHandler(route, (f: () => Future[Unit]) => unsafeRunAsync(f), config.maxContentLength), eventLoopGroup, socketOverride ) diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala index 166f1da01f..ade1e5bb46 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala @@ -20,11 +20,18 @@ import sttp.tapir.server.netty.NettyResponseContent.{ import sttp.tapir.server.netty.{NettyResponse, NettyServerRequest, Route} import scala.collection.JavaConverters._ +import scala.collection.mutable.{Queue => MutableQueue} +import scala.concurrent.Future -class NettyServerHandler[F[_]](route: Route[F], unsafeRunAsync: (() => F[Unit]) => Unit, maxContentLength: Option[Int])(implicit - me: MonadError[F] +class NettyServerHandler[F[_]](route: Route[F], unsafeRunAsync: (() => F[Unit]) => (() => Future[Unit]), maxContentLength: Option[Int])( + implicit me: MonadError[F] ) extends SimpleChannelInboundHandler[HttpRequest] { + // We keep track of the cancellation tokens for all the requests in flight. This gives us + // observability into the number of requests in flight and the ability to cancel them all + // if the connection gets closed. + private[this] val pendingResponses = MutableQueue.empty[() => Future[Unit]] + private val logger = Logger[NettyServerHandler[F]] private val EntityTooLarge: FullHttpResponse = { @@ -40,6 +47,44 @@ class NettyServerHandler[F[_]](route: Route[F], unsafeRunAsync: (() => F[Unit]) res } + override def handlerAdded(ctx: ChannelHandlerContext): Unit = + if (ctx.channel.isActive) { + initHandler(ctx) + } + override def channelActive(ctx: ChannelHandlerContext): Unit = initHandler(ctx) + + override def channelReadComplete(ctx: ChannelHandlerContext): Unit = { + logger.trace(s"channelReadComplete: ctx = $ctx") + // The normal response to read complete is to issue another read, + // but we only want to do that if there are no requests in flight, + // this will effectively limit the number of in flight requests that + // we'll handle by pushing back on the TCP stream, but it also ensures + // we don't get in the way of the request body reactive streams, + // which will be using channel read complete and read to implement + // their own back pressure + if (pendingResponses.isEmpty) { + ctx.read() + } else { + // otherwise forward it, so that any handler publishers downstream + // can handle it + ctx.fireChannelReadComplete() + } + () + } + + private[this] def initHandler(ctx: ChannelHandlerContext): Unit = + // When the channel closes we want to cancel any pending dispatches. + // Since the listener will be executed from the channels EventLoop everything is thread safe. + ctx.channel.closeFuture.addListener { (_: ChannelFuture) => + logger.debug(s"Http channel to ${ctx.channel.remoteAddress} closed. Cancelling ${pendingResponses.length} responses.") + pendingResponses.foreach(_.apply()) + } + + // AUTO_READ is off, so need to do the first read explicitly. + // this method is called when the channel is registered with the event loop, + // so ctx.read is automatically safe here w/o needing an isRegistered(). + val _ = ctx.read() + override def channelRead0(ctx: ChannelHandlerContext, request: HttpRequest): Unit = { def runRoute(req: HttpRequest) = { @@ -76,10 +121,16 @@ class NettyServerHandler[F[_]](route: Route[F], unsafeRunAsync: (() => F[Unit]) request match { case full: FullHttpRequest => val req = full.retain() - unsafeRunAsync { () => + val cancellationSwitch: () => Future[Unit] = unsafeRunAsync { () => runRoute(req) - .ensure(me.eval(req.release())) + .ensure { + me.eval { + pendingResponses.dequeue() + req.release() + } + } } // exceptions should be handled + pendingResponses.enqueue(cancellationSwitch) case req: StreamedHttpRequest => unsafeRunAsync { () => runRoute(req) From a7a5669033a18cdb98ce80e96ab1f59f533d1924 Mon Sep 17 00:00:00 2001 From: kciesielski Date: Thu, 19 Oct 2023 10:49:29 +0200 Subject: [PATCH 02/19] Add cancellation support to Netty ZIO --- .../sttp/tapir/server/netty/zio/NettyZioServer.scala | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/NettyZioServer.scala b/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/NettyZioServer.scala index 91e2499164..4a77c7f717 100644 --- a/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/NettyZioServer.scala +++ b/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/NettyZioServer.scala @@ -13,6 +13,8 @@ import zio.{RIO, Unsafe, ZIO} import java.net.{InetSocketAddress, SocketAddress} import java.nio.file.{Path, Paths} import java.util.UUID +import scala.concurrent.Future +import scala.concurrent.ExecutionContext.Implicits case class NettyZioServer[R](routes: Vector[RIO[R, Route[RIO[R, *]]]], options: NettyZioServerOptions[R], config: NettyConfig) { def addEndpoint(se: ZServerEndpoint[R, ZioStreams]): NettyZioServer[R] = addEndpoints(List(se)) @@ -55,6 +57,10 @@ case class NettyZioServer[R](routes: Vector[RIO[R, Route[RIO[R, *]]]], options: NettyZioDomainSocketBinding(socket, stop) } + private def unsafeRunAsync(runtime: zio.Runtime[R])(block: () => RIO[R, Unit]): () => Future[Unit] = + val cancelable = Unsafe.unsafe(implicit u => runtime.unsafe.runToFuture(block())) + () => cancelable.cancel().map(_ => ())(Implicits.global) + private def startUsingSocketOverride[SA <: SocketAddress](socketOverride: Option[SA]): RIO[R, (SA, () => RIO[R, Unit])] = for { runtime <- ZIO.runtime[R] routes <- ZIO.foreach(routes)(identity) @@ -67,7 +73,7 @@ case class NettyZioServer[R](routes: Vector[RIO[R, Route[RIO[R, *]]]], options: config, new NettyServerHandler[RIO[R, *]]( route, - (f: () => RIO[R, Unit]) => Unsafe.unsafe(implicit u => runtime.unsafe.runToFuture(f())), + unsafeRunAsync(runtime), config.maxContentLength ), eventLoopGroup, From d77a527c3e85896cbb3523c56c774413b750456b Mon Sep 17 00:00:00 2001 From: kciesielski Date: Thu, 19 Oct 2023 20:40:15 +0200 Subject: [PATCH 03/19] Adjust to Scala 2 --- .../sttp/tapir/server/netty/NettyFutureServer.scala | 3 ++- .../server/netty/internal/NettyServerHandler.scala | 10 +++------- .../sttp/tapir/server/netty/zio/NettyZioServer.scala | 3 ++- 3 files changed, 7 insertions(+), 9 deletions(-) diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyFutureServer.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyFutureServer.scala index a343758a98..4c08ac2c3a 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyFutureServer.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyFutureServer.scala @@ -49,9 +49,10 @@ case class NettyFutureServer(routes: Vector[FutureRoute], options: NettyFutureSe NettyFutureDomainSocketBinding(socket, stop) } - private def unsafeRunAsync(block: () => Future[Unit]): () => Future[Unit] = + private def unsafeRunAsync(block: () => Future[Unit]): () => Future[Unit] = { block() () => Future.unit // noop cancellation handler, we can't cancel native Futures + } private def startUsingSocketOverride[SA <: SocketAddress](socketOverride: Option[SA]): Future[(SA, () => Future[Unit])] = { val eventLoopGroup = config.eventLoopConfig.initEventLoopGroup() diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala index ade1e5bb46..3d784a39a3 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala @@ -72,18 +72,14 @@ class NettyServerHandler[F[_]](route: Route[F], unsafeRunAsync: (() => F[Unit]) () } - private[this] def initHandler(ctx: ChannelHandlerContext): Unit = + private[this] def initHandler(ctx: ChannelHandlerContext): Unit = { // When the channel closes we want to cancel any pending dispatches. // Since the listener will be executed from the channels EventLoop everything is thread safe. - ctx.channel.closeFuture.addListener { (_: ChannelFuture) => + val _ = ctx.channel.closeFuture.addListener { (_: ChannelFuture) => logger.debug(s"Http channel to ${ctx.channel.remoteAddress} closed. Cancelling ${pendingResponses.length} responses.") pendingResponses.foreach(_.apply()) } - - // AUTO_READ is off, so need to do the first read explicitly. - // this method is called when the channel is registered with the event loop, - // so ctx.read is automatically safe here w/o needing an isRegistered(). - val _ = ctx.read() + } override def channelRead0(ctx: ChannelHandlerContext, request: HttpRequest): Unit = { diff --git a/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/NettyZioServer.scala b/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/NettyZioServer.scala index 4a77c7f717..33642dddfc 100644 --- a/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/NettyZioServer.scala +++ b/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/NettyZioServer.scala @@ -57,9 +57,10 @@ case class NettyZioServer[R](routes: Vector[RIO[R, Route[RIO[R, *]]]], options: NettyZioDomainSocketBinding(socket, stop) } - private def unsafeRunAsync(runtime: zio.Runtime[R])(block: () => RIO[R, Unit]): () => Future[Unit] = + private def unsafeRunAsync(runtime: zio.Runtime[R])(block: () => RIO[R, Unit]): () => Future[Unit] = { val cancelable = Unsafe.unsafe(implicit u => runtime.unsafe.runToFuture(block())) () => cancelable.cancel().map(_ => ())(Implicits.global) + } private def startUsingSocketOverride[SA <: SocketAddress](socketOverride: Option[SA]): RIO[R, (SA, () => RIO[R, Unit])] = for { runtime <- ZIO.runtime[R] From 64ff1667bc3312599021c91a7da0f084a51d5021 Mon Sep 17 00:00:00 2001 From: kciesielski Date: Mon, 23 Oct 2023 16:47:55 +0200 Subject: [PATCH 04/19] Fix handling of correct request/response ordering --- .../server/netty/cats/NettyCatsServer.scala | 6 +- .../server/netty/NettyFutureServer.scala | 10 +- .../netty/internal/NettyServerHandler.scala | 141 +++++++++++++----- .../server/netty/zio/NettyZioServer.scala | 8 +- 4 files changed, 121 insertions(+), 44 deletions(-) diff --git a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServer.scala b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServer.scala index e6fe404d15..1c361ab796 100644 --- a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServer.scala +++ b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServer.scala @@ -17,6 +17,8 @@ import java.nio.file.{Path, Paths} import java.util.UUID import sttp.capabilities.fs2.Fs2Streams import scala.concurrent.Future +import sttp.tapir.server.model.ServerResponse +import sttp.tapir.server.netty.NettyResponse case class NettyCatsServer[F[_]: Async](routes: Vector[Route[F]], options: NettyCatsServerOptions[F], config: NettyConfig) { def addEndpoint(se: ServerEndpoint[Fs2Streams[F], F]): NettyCatsServer[F] = addEndpoints(List(se)) @@ -54,8 +56,8 @@ case class NettyCatsServer[F[_]: Async](routes: Vector[Route[F]], options: Netty NettyCatsDomainSocketBinding(socket, stop) } - private def unsafeRunAsync(block: () => F[Unit]): () => Future[Unit] = - options.dispatcher.unsafeRunCancelable(block()) + private def unsafeRunAsync(block: () => F[ServerResponse[NettyResponse]]): (Future[ServerResponse[NettyResponse]], () => Future[Unit]) = + options.dispatcher.unsafeToFutureCancelable(block()) private def startUsingSocketOverride[SA <: SocketAddress](socketOverride: Option[SA]): F[(SA, () => F[Unit])] = { val eventLoopGroup = config.eventLoopConfig.initEventLoopGroup() diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyFutureServer.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyFutureServer.scala index 4c08ac2c3a..d92c704b02 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyFutureServer.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyFutureServer.scala @@ -11,6 +11,7 @@ import java.net.{InetSocketAddress, SocketAddress} import java.nio.file.{Path, Paths} import java.util.UUID import scala.concurrent.{ExecutionContext, Future} +import sttp.tapir.server.model.ServerResponse case class NettyFutureServer(routes: Vector[FutureRoute], options: NettyFutureServerOptions, config: NettyConfig)(implicit ec: ExecutionContext @@ -49,9 +50,10 @@ case class NettyFutureServer(routes: Vector[FutureRoute], options: NettyFutureSe NettyFutureDomainSocketBinding(socket, stop) } - private def unsafeRunAsync(block: () => Future[Unit]): () => Future[Unit] = { - block() - () => Future.unit // noop cancellation handler, we can't cancel native Futures + private def unsafeRunAsync( + block: () => Future[ServerResponse[NettyResponse]] + ): (Future[ServerResponse[NettyResponse]], () => Future[Unit]) = { + (block(), () => Future.unit) // noop cancellation handler, we can't cancel native Futures } private def startUsingSocketOverride[SA <: SocketAddress](socketOverride: Option[SA]): Future[(SA, () => Future[Unit])] = { @@ -62,7 +64,7 @@ case class NettyFutureServer(routes: Vector[FutureRoute], options: NettyFutureSe val channelFuture = NettyBootstrap( config, - new NettyServerHandler(route, (f: () => Future[Unit]) => unsafeRunAsync(f), config.maxContentLength), + new NettyServerHandler(route, unsafeRunAsync, config.maxContentLength), eventLoopGroup, socketOverride ) diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala index 3d784a39a3..fa5ce16d43 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala @@ -22,11 +22,37 @@ import sttp.tapir.server.netty.{NettyResponse, NettyServerRequest, Route} import scala.collection.JavaConverters._ import scala.collection.mutable.{Queue => MutableQueue} import scala.concurrent.Future - -class NettyServerHandler[F[_]](route: Route[F], unsafeRunAsync: (() => F[Unit]) => (() => Future[Unit]), maxContentLength: Option[Int])( - implicit me: MonadError[F] +import scala.util.Failure +import scala.util.Success +import scala.util.control.NonFatal +import scala.concurrent.ExecutionContext + +/** @param unsafeRunAsync + * Function which dispatches given effect to run asynchronously, returning its result as a Future, and function of type `() => + * Future[Unit]` allowing cancellation of that Future. For example, this can be realized by + * `cats.effect.std.Dispatcher.unsafeToFutureCancelable`. + */ +class NettyServerHandler[F[_]]( + route: Route[F], + unsafeRunAsync: (() => F[ServerResponse[NettyResponse]]) => (Future[ServerResponse[NettyResponse]], () => Future[Unit]), + maxContentLength: Option[Int] +)(implicit + me: MonadError[F] ) extends SimpleChannelInboundHandler[HttpRequest] { + // By using the Netty event loop assigned to this channel we get two benefits: + // 1. We can avoid the necessary hopping around of threads since Netty pipelines will + // only pass events up and down from within the event loop to which it is assigned. + // That means calls to ctx.read(), and ct.write(..), would have to be trampolined otherwise. + // 2. We get serialization of execution: the EventLoop is a serial execution queue so + // we can rest easy knowing that no two events will be executed in parallel. + private[this] var eventLoopContext: ExecutionContext = _ + + // This is used essentially as a queue, each incoming request attaches callbacks to this + // and replaces it to ensure that responses are written out in the same order that they came + // in. + private[this] var lastResponseSent: Future[Unit] = Future.unit + // We keep track of the cancellation tokens for all the requests in flight. This gives us // observability into the number of requests in flight and the ability to cancel them all // if the connection gets closed. @@ -73,64 +99,107 @@ class NettyServerHandler[F[_]](route: Route[F], unsafeRunAsync: (() => F[Unit]) } private[this] def initHandler(ctx: ChannelHandlerContext): Unit = { - // When the channel closes we want to cancel any pending dispatches. - // Since the listener will be executed from the channels EventLoop everything is thread safe. - val _ = ctx.channel.closeFuture.addListener { (_: ChannelFuture) => - logger.debug(s"Http channel to ${ctx.channel.remoteAddress} closed. Cancelling ${pendingResponses.length} responses.") - pendingResponses.foreach(_.apply()) + if (eventLoopContext == null) { + // Initialize our ExecutionContext + eventLoopContext = ExecutionContext.fromExecutor(ctx.channel.eventLoop) + + // When the channel closes we want to cancel any pending dispatches. + // Since the listener will be executed from the channels EventLoop everything is thread safe. + val _ = ctx.channel.closeFuture.addListener { (_: ChannelFuture) => + logger.debug(s"Http channel to ${ctx.channel.remoteAddress} closed. Cancelling ${pendingResponses.length} responses.") + pendingResponses.foreach(_.apply()) + } } } override def channelRead0(ctx: ChannelHandlerContext, request: HttpRequest): Unit = { - def runRoute(req: HttpRequest) = { + def runRoute(req: HttpRequest): F[ServerResponse[NettyResponse]] = { route(NettyServerRequest(req)) .map { case Some(response) => response case None => ServerResponse.notFound } - .flatMap((serverResponse: ServerResponse[NettyResponse]) => - // in ZIO, exceptions thrown in .map become defects - instead, we want them represented as errors so that - // we get the 500 response, instead of dropping the request - try handleResponse(ctx, req, serverResponse).unit - catch { - case e: Exception => me.error[Unit](e) - } - ) - .handleError { case ex: Exception => - logger.error("Error while processing the request", ex) - // send 500 - val res = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.INTERNAL_SERVER_ERROR) - res.headers().set(HttpHeaderNames.CONTENT_LENGTH, 0) - res.handleCloseAndKeepAliveHeaders(req) - - ctx.writeAndFlush(res).closeIfNeeded(req) - me.unit(()) - } } if (HttpUtil.is100ContinueExpected(request)) { ctx.writeAndFlush(new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.CONTINUE)) () } else { - request match { + request match { // TODO refactor repetitions case full: FullHttpRequest => val req = full.retain() - val cancellationSwitch: () => Future[Unit] = unsafeRunAsync { () => + val (runningFuture, cancellationSwitch) = unsafeRunAsync { () => runRoute(req) - .ensure { - me.eval { - pendingResponses.dequeue() + } + pendingResponses.enqueue(cancellationSwitch) + lastResponseSent = lastResponseSent.flatMap[Unit] { _ => + runningFuture.transform { + case Success(serverResponse) => + pendingResponses.dequeue() + try { + handleResponse(ctx, req, serverResponse) req.release() + Success(()) + } catch { + case NonFatal(ex) => + logger.error("Error while processing the request", ex) + // send 500 + val res = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.INTERNAL_SERVER_ERROR) + res.headers().set(HttpHeaderNames.CONTENT_LENGTH, 0) + res.handleCloseAndKeepAliveHeaders(req) + + ctx.writeAndFlush(res).closeIfNeeded(req) + Failure(ex) } - } - } // exceptions should be handled - pendingResponses.enqueue(cancellationSwitch) + case Failure(NonFatal(ex)) => + logger.error("Error while processing the request", ex) + // send 500 + val res = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.INTERNAL_SERVER_ERROR) + res.headers().set(HttpHeaderNames.CONTENT_LENGTH, 0) + res.handleCloseAndKeepAliveHeaders(req) + + ctx.writeAndFlush(res).closeIfNeeded(req) + Failure(ex) + case Failure(ex) => Failure(ex) + }(eventLoopContext) + }(eventLoopContext) case req: StreamedHttpRequest => - unsafeRunAsync { () => + val (runningFuture, cancellationSwitch) = unsafeRunAsync { () => runRoute(req) } + pendingResponses.enqueue(cancellationSwitch) + lastResponseSent = lastResponseSent.flatMap[Unit] { _ => + runningFuture.transform { + case Success(serverResponse) => + pendingResponses.dequeue() + try { + handleResponse(ctx, req, serverResponse) + Success(()) + } catch { + case NonFatal(ex) => + logger.error("Error while processing the request", ex) + // send 500 + val res = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.INTERNAL_SERVER_ERROR) + res.headers().set(HttpHeaderNames.CONTENT_LENGTH, 0) + res.handleCloseAndKeepAliveHeaders(req) + + ctx.writeAndFlush(res).closeIfNeeded(req) + Failure(ex) + } + case Failure(NonFatal(ex)) => + logger.error("Error while processing the request", ex) + // send 500 + val res = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.INTERNAL_SERVER_ERROR) + res.headers().set(HttpHeaderNames.CONTENT_LENGTH, 0) + res.handleCloseAndKeepAliveHeaders(req) + + ctx.writeAndFlush(res).closeIfNeeded(req) + Failure(ex) + case Failure(ex) => Failure(ex) + }(eventLoopContext) + }(eventLoopContext) case _ => throw new UnsupportedOperationException(s"Unexpected Netty request type: ${request.getClass.getName}") } diff --git a/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/NettyZioServer.scala b/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/NettyZioServer.scala index 33642dddfc..10d8915cd0 100644 --- a/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/NettyZioServer.scala +++ b/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/NettyZioServer.scala @@ -15,6 +15,8 @@ import java.nio.file.{Path, Paths} import java.util.UUID import scala.concurrent.Future import scala.concurrent.ExecutionContext.Implicits +import sttp.tapir.server.model.ServerResponse +import sttp.tapir.server.netty.NettyResponse case class NettyZioServer[R](routes: Vector[RIO[R, Route[RIO[R, *]]]], options: NettyZioServerOptions[R], config: NettyConfig) { def addEndpoint(se: ZServerEndpoint[R, ZioStreams]): NettyZioServer[R] = addEndpoints(List(se)) @@ -57,9 +59,11 @@ case class NettyZioServer[R](routes: Vector[RIO[R, Route[RIO[R, *]]]], options: NettyZioDomainSocketBinding(socket, stop) } - private def unsafeRunAsync(runtime: zio.Runtime[R])(block: () => RIO[R, Unit]): () => Future[Unit] = { + private def unsafeRunAsync( + runtime: zio.Runtime[R] + )(block: () => RIO[R, ServerResponse[NettyResponse]]): (Future[ServerResponse[NettyResponse]], () => Future[Unit]) = { val cancelable = Unsafe.unsafe(implicit u => runtime.unsafe.runToFuture(block())) - () => cancelable.cancel().map(_ => ())(Implicits.global) + (cancelable, () => cancelable.cancel().map(_ => ())(Implicits.global)) } private def startUsingSocketOverride[SA <: SocketAddress](socketOverride: Option[SA]): RIO[R, (SA, () => RIO[R, Unit])] = for { From c1ef9b422c3f273032ee17f48898c7047dbda25a Mon Sep 17 00:00:00 2001 From: kciesielski Date: Mon, 23 Oct 2023 16:54:46 +0200 Subject: [PATCH 05/19] Remove custom implementation of channelReadComplete --- .../netty/internal/NettyServerHandler.scala | 19 ------------------- 1 file changed, 19 deletions(-) diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala index fa5ce16d43..6fed4ff05d 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala @@ -79,25 +79,6 @@ class NettyServerHandler[F[_]]( } override def channelActive(ctx: ChannelHandlerContext): Unit = initHandler(ctx) - override def channelReadComplete(ctx: ChannelHandlerContext): Unit = { - logger.trace(s"channelReadComplete: ctx = $ctx") - // The normal response to read complete is to issue another read, - // but we only want to do that if there are no requests in flight, - // this will effectively limit the number of in flight requests that - // we'll handle by pushing back on the TCP stream, but it also ensures - // we don't get in the way of the request body reactive streams, - // which will be using channel read complete and read to implement - // their own back pressure - if (pendingResponses.isEmpty) { - ctx.read() - } else { - // otherwise forward it, so that any handler publishers downstream - // can handle it - ctx.fireChannelReadComplete() - } - () - } - private[this] def initHandler(ctx: ChannelHandlerContext): Unit = { if (eventLoopContext == null) { // Initialize our ExecutionContext From 239068e38b4d8fb7194a3eba821643f52d022fd4 Mon Sep 17 00:00:00 2001 From: kciesielski Date: Tue, 24 Oct 2023 08:20:22 +0200 Subject: [PATCH 06/19] Organize imports --- .../sttp/tapir/server/netty/cats/NettyCatsServer.scala | 7 +++---- .../scala/sttp/tapir/server/netty/zio/NettyZioServer.scala | 7 +++---- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServer.scala b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServer.scala index 1c361ab796..7e7b3c8fea 100644 --- a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServer.scala +++ b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServer.scala @@ -5,20 +5,19 @@ import cats.effect.{Async, IO, Resource} import cats.syntax.all._ import io.netty.channel._ import io.netty.channel.unix.DomainSocketAddress +import sttp.capabilities.fs2.Fs2Streams import sttp.monad.MonadError import sttp.tapir.integ.cats.effect.CatsMonadError import sttp.tapir.server.ServerEndpoint +import sttp.tapir.server.model.ServerResponse import sttp.tapir.server.netty.cats.internal.CatsUtil.{nettyChannelFutureToScala, nettyFutureToScala} import sttp.tapir.server.netty.internal.{NettyBootstrap, NettyServerHandler} -import sttp.tapir.server.netty.{NettyConfig, Route} +import sttp.tapir.server.netty.{NettyConfig, NettyResponse, Route} import java.net.{InetSocketAddress, SocketAddress} import java.nio.file.{Path, Paths} import java.util.UUID -import sttp.capabilities.fs2.Fs2Streams import scala.concurrent.Future -import sttp.tapir.server.model.ServerResponse -import sttp.tapir.server.netty.NettyResponse case class NettyCatsServer[F[_]: Async](routes: Vector[Route[F]], options: NettyCatsServerOptions[F], config: NettyConfig) { def addEndpoint(se: ServerEndpoint[Fs2Streams[F], F]): NettyCatsServer[F] = addEndpoints(List(se)) diff --git a/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/NettyZioServer.scala b/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/NettyZioServer.scala index 10d8915cd0..f36e9f4d45 100644 --- a/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/NettyZioServer.scala +++ b/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/NettyZioServer.scala @@ -4,19 +4,18 @@ import io.netty.channel._ import io.netty.channel.unix.DomainSocketAddress import sttp.capabilities.zio.ZioStreams import sttp.tapir.server.ServerEndpoint -import sttp.tapir.server.netty.{NettyConfig, Route} +import sttp.tapir.server.model.ServerResponse import sttp.tapir.server.netty.internal.{NettyBootstrap, NettyServerHandler} import sttp.tapir.server.netty.zio.internal.ZioUtil.{nettyChannelFutureToScala, nettyFutureToScala} +import sttp.tapir.server.netty.{NettyConfig, NettyResponse, Route} import sttp.tapir.ztapir.{RIOMonadError, ZServerEndpoint} import zio.{RIO, Unsafe, ZIO} import java.net.{InetSocketAddress, SocketAddress} import java.nio.file.{Path, Paths} import java.util.UUID -import scala.concurrent.Future import scala.concurrent.ExecutionContext.Implicits -import sttp.tapir.server.model.ServerResponse -import sttp.tapir.server.netty.NettyResponse +import scala.concurrent.Future case class NettyZioServer[R](routes: Vector[RIO[R, Route[RIO[R, *]]]], options: NettyZioServerOptions[R], config: NettyConfig) { def addEndpoint(se: ZServerEndpoint[R, ZioStreams]): NettyZioServer[R] = addEndpoints(List(se)) From ef26db48f12eb21d57971980b70224fa31759575 Mon Sep 17 00:00:00 2001 From: kciesielski Date: Tue, 24 Oct 2023 16:25:49 +0200 Subject: [PATCH 07/19] Add tests for cancellation (CE3) --- .../netty/cats/NettyCatsServerTest.scala | 4 +- .../tests/ServerCancellationTests.scala | 47 +++++++++++++++++++ 2 files changed, 50 insertions(+), 1 deletion(-) create mode 100644 server/tests/src/main/scala/sttp/tapir/server/tests/ServerCancellationTests.scala diff --git a/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsServerTest.scala b/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsServerTest.scala index 84e90e1266..e54967ad6a 100644 --- a/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsServerTest.scala +++ b/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsServerTest.scala @@ -31,7 +31,9 @@ class NettyCatsServerTest extends TestSuite with EitherValues { multipart = false, maxContentLength = Some(NettyCatsTestServerInterpreter.maxContentLength) ) - .tests() ++ new ServerStreamingTests(createServerTest, Fs2Streams[IO]).tests() + .tests() ++ + new ServerStreamingTests(createServerTest, Fs2Streams[IO]).tests() ++ + new ServerCancellationTests(createServerTest)(m, IO.asyncForIO).tests() IO.pure((tests, eventLoopGroup)) } { case (_, eventLoopGroup) => diff --git a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerCancellationTests.scala b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerCancellationTests.scala new file mode 100644 index 0000000000..fb3635057f --- /dev/null +++ b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerCancellationTests.scala @@ -0,0 +1,47 @@ +package sttp.tapir.server.tests + +import cats.effect.IO +import cats.effect.kernel.Async +import cats.effect.syntax.all._ +import cats.syntax.all._ +import org.scalatest +import org.scalatest.compatible.Assertion +import org.scalatest.matchers.should.Matchers._ +import sttp.client3._ +import sttp.monad.MonadError +import sttp.tapir._ +import sttp.tapir.tests._ + +import java.util.concurrent.TimeoutException +import java.util.concurrent.atomic.AtomicBoolean +import scala.concurrent.duration._ + +class ServerCancellationTests[F[_], OPTIONS, ROUTE](createServerTest: CreateServerTest[F, Any, OPTIONS, ROUTE])(implicit + m: MonadError[F], + async: Async[F] +) { + import createServerTest._ + + def tests(): List[Test] = List({ + val canceled: AtomicBoolean = new AtomicBoolean(false) + testServerLogic( + endpoint + .out(plainBody[String]) + .serverLogic { _ => + (async.sleep(15.seconds) >> pureResult("processing finished".asRight)).onCancel(m.eval(canceled.set(true))) + }, + "Client cancelling request triggers cancellation on the server" + ) { (backend, baseUri) => + val resp: IO[Assertion] = basicRequest.get(uri"$baseUri").send(backend).timeout(300.millis).map { case result => + fail(s"Expected cancellation, but received a result: $result") + } + + resp.handleErrorWith { + case _: TimeoutException => + IO(assert(canceled.get(), "Cancellation expected, but not registered!")) + case other => + IO(fail(s"TimeoutException expected, but got $other")) + } + } + }) +} From 8b047e50b8beb0df87c88d21151894c76d396e89 Mon Sep 17 00:00:00 2001 From: kciesielski Date: Tue, 24 Oct 2023 16:26:16 +0200 Subject: [PATCH 08/19] wip --- .../sttp/tapir/server/netty/zio/NettyZioServerTest.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/server/netty-server/zio/src/test/scala/sttp/tapir/server/netty/zio/NettyZioServerTest.scala b/server/netty-server/zio/src/test/scala/sttp/tapir/server/netty/zio/NettyZioServerTest.scala index eeb8766e31..227d9223db 100644 --- a/server/netty-server/zio/src/test/scala/sttp/tapir/server/netty/zio/NettyZioServerTest.scala +++ b/server/netty-server/zio/src/test/scala/sttp/tapir/server/netty/zio/NettyZioServerTest.scala @@ -10,6 +10,7 @@ import sttp.tapir.server.tests._ import sttp.tapir.tests.{Test, TestSuite} import sttp.tapir.ztapir.RIOMonadError import zio.Task +import zio.interop.catz._ import scala.concurrent.Future @@ -26,7 +27,8 @@ class NettyZioServerTest extends TestSuite with EitherValues { val tests = new AllServerTests(createServerTest, interpreter, backend, staticContent = false, multipart = false).tests() ++ - new ServerStreamingTests(createServerTest, ZioStreams).tests() + new ServerStreamingTests(createServerTest, ZioStreams).tests() ++ + new ServerCancellationTests(createServerTest)(monadError, asyncInstance).tests() IO.pure((tests, eventLoopGroup)) } { case (_, eventLoopGroup) => From 120a2da2689e1575e3d445de70e9460aaa28fb54 Mon Sep 17 00:00:00 2001 From: kciesielski Date: Wed, 25 Oct 2023 13:06:16 +0200 Subject: [PATCH 09/19] Fix cancellation test --- .../server/tests/ServerCancellationTests.scala | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerCancellationTests.scala b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerCancellationTests.scala index fb3635057f..26785e7d94 100644 --- a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerCancellationTests.scala +++ b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerCancellationTests.scala @@ -12,8 +12,8 @@ import sttp.monad.MonadError import sttp.tapir._ import sttp.tapir.tests._ -import java.util.concurrent.TimeoutException import java.util.concurrent.atomic.AtomicBoolean +import java.util.concurrent.{Semaphore, TimeoutException} import scala.concurrent.duration._ class ServerCancellationTests[F[_], OPTIONS, ROUTE](createServerTest: CreateServerTest[F, Any, OPTIONS, ROUTE])(implicit @@ -23,12 +23,14 @@ class ServerCancellationTests[F[_], OPTIONS, ROUTE](createServerTest: CreateServ import createServerTest._ def tests(): List[Test] = List({ + val canceledSemaphore = new Semaphore(1) val canceled: AtomicBoolean = new AtomicBoolean(false) testServerLogic( endpoint .out(plainBody[String]) .serverLogic { _ => - (async.sleep(15.seconds) >> pureResult("processing finished".asRight)).onCancel(m.eval(canceled.set(true))) + (m.eval(canceledSemaphore.acquire())) >> (async.sleep(15.seconds) >> pureResult("processing finished".asRight)) + .onCancel(m.eval(canceled.set(true)) >> m.eval(canceledSemaphore.release())) }, "Client cancelling request triggers cancellation on the server" ) { (backend, baseUri) => @@ -37,8 +39,11 @@ class ServerCancellationTests[F[_], OPTIONS, ROUTE](createServerTest: CreateServ } resp.handleErrorWith { - case _: TimeoutException => - IO(assert(canceled.get(), "Cancellation expected, but not registered!")) + case _: TimeoutException => // expected, this is how we trigged client-side cancellation + IO(canceledSemaphore.acquire()) + .timeout(3.seconds) + .handleError(_ => IO.eval(fail("Timeout when waiting for cancellation to be handled as expected"))) >> + IO(assert(canceled.get(), "Cancellation expected, but not registered!")) case other => IO(fail(s"TimeoutException expected, but got $other")) } From 94d2b3fd4ed09c5ad7018ac1086b6e88a8d01797 Mon Sep 17 00:00:00 2001 From: kciesielski Date: Wed, 25 Oct 2023 14:04:25 +0200 Subject: [PATCH 10/19] Remove problematic redundant imports --- .../scala/sttp/tapir/server/netty/zio/NettyZioServer.scala | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/NettyZioServer.scala b/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/NettyZioServer.scala index f36e9f4d45..c642949bc6 100644 --- a/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/NettyZioServer.scala +++ b/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/NettyZioServer.scala @@ -61,7 +61,11 @@ case class NettyZioServer[R](routes: Vector[RIO[R, Route[RIO[R, *]]]], options: private def unsafeRunAsync( runtime: zio.Runtime[R] )(block: () => RIO[R, ServerResponse[NettyResponse]]): (Future[ServerResponse[NettyResponse]], () => Future[Unit]) = { - val cancelable = Unsafe.unsafe(implicit u => runtime.unsafe.runToFuture(block())) + val cancelable = Unsafe.unsafe(implicit u => + runtime.unsafe.runToFuture( + block() + ) + ) (cancelable, () => cancelable.cancel().map(_ => ())(Implicits.global)) } From c877bc67d2ea616e4b9efc61c44f11b57624f877 Mon Sep 17 00:00:00 2001 From: kciesielski Date: Wed, 25 Oct 2023 17:16:38 +0200 Subject: [PATCH 11/19] More refactoring and test tweaks --- .../netty/cats/NettyCatsServerTest.scala | 5 +- .../NettyFs2StreamingCancellationTest.scala | 54 ++++++++ .../netty/internal/NettyServerHandler.scala | 125 +++++++----------- .../tests/ServerCancellationTests.scala | 35 ++--- 4 files changed, 121 insertions(+), 98 deletions(-) create mode 100644 server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyFs2StreamingCancellationTest.scala diff --git a/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsServerTest.scala b/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsServerTest.scala index e54967ad6a..7b3f2a1303 100644 --- a/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsServerTest.scala +++ b/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsServerTest.scala @@ -31,9 +31,10 @@ class NettyCatsServerTest extends TestSuite with EitherValues { multipart = false, maxContentLength = Some(NettyCatsTestServerInterpreter.maxContentLength) ) - .tests() ++ + .tests() ++ new ServerStreamingTests(createServerTest, Fs2Streams[IO]).tests() ++ - new ServerCancellationTests(createServerTest)(m, IO.asyncForIO).tests() + new ServerCancellationTests(createServerTest)(m, IO.asyncForIO).tests() ++ + new NettyFs2StreamingCancellationTest(createServerTest).tests() IO.pure((tests, eventLoopGroup)) } { case (_, eventLoopGroup) => diff --git a/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyFs2StreamingCancellationTest.scala b/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyFs2StreamingCancellationTest.scala new file mode 100644 index 0000000000..e88812f476 --- /dev/null +++ b/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyFs2StreamingCancellationTest.scala @@ -0,0 +1,54 @@ +package sttp.tapir.server.netty.cats + +import cats.effect.IO +import cats.syntax.all._ +import org.scalatest.matchers.should.Matchers._ +import sttp.capabilities.fs2.Fs2Streams +import sttp.client3._ +import sttp.monad.MonadError +import sttp.tapir.integ.cats.effect.CatsMonadError +import sttp.tapir.server.tests.{CreateServerTest, _} +import sttp.tapir.tests._ +import sttp.tapir.{CodecFormat, _} + +import java.nio.charset.StandardCharsets +import scala.concurrent.duration._ +import cats.effect.std.Queue +import cats.effect.unsafe.implicits.global + +class NettyFs2StreamingCancellationTest[OPTIONS, ROUTE](createServerTest: CreateServerTest[IO, Fs2Streams[IO], OPTIONS, ROUTE]) { + import createServerTest._ + + implicit val m: MonadError[IO] = new CatsMonadError[IO]() + def tests(): List[Test] = List({ + val buffer = Queue.unbounded[IO, Byte].unsafeRunSync() + val body_20_slowly_emitted_bytes = + fs2.Stream.awakeEvery[IO](100.milliseconds).map(_ => 42.toByte).evalMap(b => { buffer.offer(b) >> IO.pure(b) }).take(200) + testServer( + endpoint.get + .in("streamCanceled") + .out(streamTextBody(Fs2Streams[IO])(CodecFormat.TextPlain(), Some(StandardCharsets.UTF_8))), + "Client cancelling streaming triggers cancellation on the server" + )(_ => pureResult(body_20_slowly_emitted_bytes.asRight[Unit])) { (backend, baseUri) => + + val expectedMaxAccumulated = 3 + + basicRequest + .get(uri"$baseUri/streamCanceled") + .send(backend) + .timeout(300.millis) + .attempt >> + IO.sleep(600.millis) + .flatMap(r => + buffer.size.flatMap(accumulated => + IO( + assert( + accumulated <= expectedMaxAccumulated, + s"Buffer accumulated $accumulated elements. Expected < $expectedMaxAccumulated due to cancellation." + ) + ) + ) + ) + } + }) +} diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala index 6fed4ff05d..d760839841 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala @@ -95,92 +95,57 @@ class NettyServerHandler[F[_]]( override def channelRead0(ctx: ChannelHandlerContext, request: HttpRequest): Unit = { - def runRoute(req: HttpRequest): F[ServerResponse[NettyResponse]] = { + def writeError500(req: HttpRequest, reason: Throwable): Unit = { + logger.error("Error while processing the request", reason) + // send 500 + val res = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.INTERNAL_SERVER_ERROR) + res.headers().set(HttpHeaderNames.CONTENT_LENGTH, 0) + res.handleCloseAndKeepAliveHeaders(req) - route(NettyServerRequest(req)) - .map { - case Some(response) => response - case None => ServerResponse.notFound - } + ctx.writeAndFlush(res).closeIfNeeded(req) + + } + + def runRoute(req: HttpRequest, releaseReq: () => Any = () => ()): Unit = { + val (runningFuture, cancellationSwitch) = unsafeRunAsync { () => + route(NettyServerRequest(req)) + .map { + case Some(response) => response + case None => ServerResponse.notFound + } + } + pendingResponses.enqueue(cancellationSwitch) + lastResponseSent = lastResponseSent.flatMap[Unit] { _ => + runningFuture.transform { + case Success(serverResponse) => + pendingResponses.dequeue() + try { + handleResponse(ctx, req, serverResponse) + releaseReq() + Success(()) + } catch { + case NonFatal(ex) => + writeError500(req, ex) + Failure(ex) + } + case Failure(NonFatal(ex)) => + writeError500(req, ex) + Failure(ex) + case Failure(ex) => Failure(ex) + }(eventLoopContext) + }(eventLoopContext) } if (HttpUtil.is100ContinueExpected(request)) { ctx.writeAndFlush(new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.CONTINUE)) () } else { - request match { // TODO refactor repetitions + request match { case full: FullHttpRequest => val req = full.retain() - val (runningFuture, cancellationSwitch) = unsafeRunAsync { () => - runRoute(req) - } - pendingResponses.enqueue(cancellationSwitch) - lastResponseSent = lastResponseSent.flatMap[Unit] { _ => - runningFuture.transform { - case Success(serverResponse) => - pendingResponses.dequeue() - try { - handleResponse(ctx, req, serverResponse) - req.release() - Success(()) - } catch { - case NonFatal(ex) => - logger.error("Error while processing the request", ex) - // send 500 - val res = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.INTERNAL_SERVER_ERROR) - res.headers().set(HttpHeaderNames.CONTENT_LENGTH, 0) - res.handleCloseAndKeepAliveHeaders(req) - - ctx.writeAndFlush(res).closeIfNeeded(req) - Failure(ex) - } - case Failure(NonFatal(ex)) => - logger.error("Error while processing the request", ex) - // send 500 - val res = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.INTERNAL_SERVER_ERROR) - res.headers().set(HttpHeaderNames.CONTENT_LENGTH, 0) - res.handleCloseAndKeepAliveHeaders(req) - - ctx.writeAndFlush(res).closeIfNeeded(req) - Failure(ex) - case Failure(ex) => Failure(ex) - }(eventLoopContext) - }(eventLoopContext) + runRoute(req, () => req.release()) case req: StreamedHttpRequest => - val (runningFuture, cancellationSwitch) = unsafeRunAsync { () => - runRoute(req) - } - pendingResponses.enqueue(cancellationSwitch) - lastResponseSent = lastResponseSent.flatMap[Unit] { _ => - runningFuture.transform { - case Success(serverResponse) => - pendingResponses.dequeue() - try { - handleResponse(ctx, req, serverResponse) - Success(()) - } catch { - case NonFatal(ex) => - logger.error("Error while processing the request", ex) - // send 500 - val res = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.INTERNAL_SERVER_ERROR) - res.headers().set(HttpHeaderNames.CONTENT_LENGTH, 0) - res.handleCloseAndKeepAliveHeaders(req) - - ctx.writeAndFlush(res).closeIfNeeded(req) - Failure(ex) - } - case Failure(NonFatal(ex)) => - logger.error("Error while processing the request", ex) - // send 500 - val res = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.INTERNAL_SERVER_ERROR) - res.headers().set(HttpHeaderNames.CONTENT_LENGTH, 0) - res.handleCloseAndKeepAliveHeaders(req) - - ctx.writeAndFlush(res).closeIfNeeded(req) - Failure(ex) - case Failure(ex) => Failure(ex) - }(eventLoopContext) - }(eventLoopContext) + runRoute(req) case _ => throw new UnsupportedOperationException(s"Unexpected Netty request type: ${request.getClass.getName}") } @@ -253,22 +218,22 @@ class NettyServerHandler[F[_]]( if (!HttpUtil.is100ContinueExpected(req) && !HttpUtil.isKeepAlive(req)) { val future: ChannelFuture = ctx.writeAndFlush(EntityTooLargeClose.retainedDuplicate()) - future.addListener(new ChannelFutureListener() { + val _ = future.addListener(new ChannelFutureListener() { override def operationComplete(future: ChannelFuture) = { if (!future.isSuccess()) { logger.warn("Failed to send a 413 Request Entity Too Large.", future.cause()) } - ctx.close() + val _ = ctx.close() } }) } else { - ctx + val _ = ctx .writeAndFlush(EntityTooLarge.retainedDuplicate()) .addListener(new ChannelFutureListener() { override def operationComplete(future: ChannelFuture) = { if (!future.isSuccess()) { logger.warn("Failed to send a 413 Request Entity Too Large.", future.cause()) - ctx.close() + val _ = ctx.close() } } }) diff --git a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerCancellationTests.scala b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerCancellationTests.scala index 26785e7d94..1f95c89822 100644 --- a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerCancellationTests.scala +++ b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerCancellationTests.scala @@ -4,8 +4,6 @@ import cats.effect.IO import cats.effect.kernel.Async import cats.effect.syntax.all._ import cats.syntax.all._ -import org.scalatest -import org.scalatest.compatible.Assertion import org.scalatest.matchers.should.Matchers._ import sttp.client3._ import sttp.monad.MonadError @@ -13,7 +11,7 @@ import sttp.tapir._ import sttp.tapir.tests._ import java.util.concurrent.atomic.AtomicBoolean -import java.util.concurrent.{Semaphore, TimeoutException} +import java.util.concurrent.{Semaphore, TimeUnit, TimeoutException} import scala.concurrent.duration._ class ServerCancellationTests[F[_], OPTIONS, ROUTE](createServerTest: CreateServerTest[F, Any, OPTIONS, ROUTE])(implicit @@ -29,24 +27,29 @@ class ServerCancellationTests[F[_], OPTIONS, ROUTE](createServerTest: CreateServ endpoint .out(plainBody[String]) .serverLogic { _ => - (m.eval(canceledSemaphore.acquire())) >> (async.sleep(15.seconds) >> pureResult("processing finished".asRight)) + (m.eval(canceledSemaphore.acquire())) >> (async.sleep(15.seconds) >> pureResult("processing finished".asRight[Unit])) .onCancel(m.eval(canceled.set(true)) >> m.eval(canceledSemaphore.release())) }, "Client cancelling request triggers cancellation on the server" ) { (backend, baseUri) => - val resp: IO[Assertion] = basicRequest.get(uri"$baseUri").send(backend).timeout(300.millis).map { case result => - fail(s"Expected cancellation, but received a result: $result") - } + val resp: IO[_] = basicRequest.get(uri"$baseUri").send(backend).timeout(300.millis) - resp.handleErrorWith { - case _: TimeoutException => // expected, this is how we trigged client-side cancellation - IO(canceledSemaphore.acquire()) - .timeout(3.seconds) - .handleError(_ => IO.eval(fail("Timeout when waiting for cancellation to be handled as expected"))) >> - IO(assert(canceled.get(), "Cancellation expected, but not registered!")) - case other => - IO(fail(s"TimeoutException expected, but got $other")) - } + resp + .map { case result => + fail(s"Expected cancellation, but received a result: $result") + } + .handleErrorWith { + case _: TimeoutException => // expected, this is how we trigged client-side cancellation + IO( + assert( + canceledSemaphore.tryAcquire(3L, TimeUnit.SECONDS), + "Timeout when waiting for cancellation to be handled as expected" + ) + ) >> + IO(assert(canceled.get(), "Cancellation expected, but not registered!")) + case other => + IO(fail(s"TimeoutException expected, but got $other")) + } } }) } From 68b8684ae55fd9d69189127d99409df261105cc0 Mon Sep 17 00:00:00 2001 From: kciesielski Date: Wed, 25 Oct 2023 17:20:46 +0200 Subject: [PATCH 12/19] Fix warning for unused param --- .../server/netty/cats/NettyFs2StreamingCancellationTest.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyFs2StreamingCancellationTest.scala b/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyFs2StreamingCancellationTest.scala index e88812f476..1eb0412787 100644 --- a/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyFs2StreamingCancellationTest.scala +++ b/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyFs2StreamingCancellationTest.scala @@ -39,7 +39,7 @@ class NettyFs2StreamingCancellationTest[OPTIONS, ROUTE](createServerTest: Create .timeout(300.millis) .attempt >> IO.sleep(600.millis) - .flatMap(r => + .flatMap(_ => buffer.size.flatMap(accumulated => IO( assert( From 0c7956f4e1ab31f1a2a1068909cd02e227a51c23 Mon Sep 17 00:00:00 2001 From: kciesielski Date: Wed, 25 Oct 2023 21:39:31 +0200 Subject: [PATCH 13/19] Extend timeout when waiting for semaphore - Might be needed for slow CI --- .../scala/sttp/tapir/server/tests/ServerCancellationTests.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerCancellationTests.scala b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerCancellationTests.scala index 1f95c89822..10ea7c3884 100644 --- a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerCancellationTests.scala +++ b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerCancellationTests.scala @@ -42,7 +42,7 @@ class ServerCancellationTests[F[_], OPTIONS, ROUTE](createServerTest: CreateServ case _: TimeoutException => // expected, this is how we trigged client-side cancellation IO( assert( - canceledSemaphore.tryAcquire(3L, TimeUnit.SECONDS), + canceledSemaphore.tryAcquire(30L, TimeUnit.SECONDS), "Timeout when waiting for cancellation to be handled as expected" ) ) >> From 0bd6f49a52b8bedfb4c0d705f5037a514d3bfbac Mon Sep 17 00:00:00 2001 From: kciesielski Date: Thu, 26 Oct 2023 10:13:36 +0200 Subject: [PATCH 14/19] Keep ordering responses even if there were failures --- flake.nix | 32 +++++++++++++++++++ .../netty/internal/NettyServerHandler.scala | 2 +- 2 files changed, 33 insertions(+), 1 deletion(-) create mode 100644 flake.nix diff --git a/flake.nix b/flake.nix new file mode 100644 index 0000000000..77bdda12f1 --- /dev/null +++ b/flake.nix @@ -0,0 +1,32 @@ +{ + description = "JDK 11 shell flake"; + + inputs = { + nixpkgs.url = "github:nixos/nixpkgs/nixos-unstable"; + flake-utils.url = "github:numtide/flake-utils"; + + mach-nix.url = "github:davhau/mach-nix"; + }; + + outputs = { self, nixpkgs, mach-nix, flake-utils, ... }: + let + pythonVersion = "python37"; + in + flake-utils.lib.eachDefaultSystem (system: + let + pkgs = nixpkgs.legacyPackages.${system}; + mach = mach-nix.lib.${system}; + + pythonEnv = mach.mkPython { + python = pythonVersion; + requirements = builtins.readFile ./requirements.txt; + }; + in + { + devShells.default = pkgs.mkShellNoCC { + packages = [ pkgs.jdk11 ]; + + }; + } + ); +} diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala index d760839841..153a93c708 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala @@ -115,7 +115,7 @@ class NettyServerHandler[F[_]]( } } pendingResponses.enqueue(cancellationSwitch) - lastResponseSent = lastResponseSent.flatMap[Unit] { _ => + lastResponseSent = lastResponseSent.transformWith { _ => runningFuture.transform { case Success(serverResponse) => pendingResponses.dequeue() From ffa9b612ebd383fffb87f21f42978282c797f42a Mon Sep 17 00:00:00 2001 From: kciesielski Date: Thu, 26 Oct 2023 10:22:14 +0200 Subject: [PATCH 15/19] Ensure that stream completes before client gives up --- .../server/netty/cats/NettyFs2StreamingCancellationTest.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyFs2StreamingCancellationTest.scala b/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyFs2StreamingCancellationTest.scala index 1eb0412787..df25c4fefc 100644 --- a/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyFs2StreamingCancellationTest.scala +++ b/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyFs2StreamingCancellationTest.scala @@ -23,7 +23,7 @@ class NettyFs2StreamingCancellationTest[OPTIONS, ROUTE](createServerTest: Create def tests(): List[Test] = List({ val buffer = Queue.unbounded[IO, Byte].unsafeRunSync() val body_20_slowly_emitted_bytes = - fs2.Stream.awakeEvery[IO](100.milliseconds).map(_ => 42.toByte).evalMap(b => { buffer.offer(b) >> IO.pure(b) }).take(200) + fs2.Stream.awakeEvery[IO](100.milliseconds).map(_ => 42.toByte).evalMap(b => { buffer.offer(b) >> IO.pure(b) }).take(100) testServer( endpoint.get .in("streamCanceled") From 883698c56bac7a0e59deea8241692df0627090e0 Mon Sep 17 00:00:00 2001 From: kciesielski Date: Thu, 26 Oct 2023 10:22:39 +0200 Subject: [PATCH 16/19] Use readTimeout to ensure cancellation is triggered on JDK11 --- .../sttp/tapir/server/tests/ServerCancellationTests.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerCancellationTests.scala b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerCancellationTests.scala index 10ea7c3884..3d3435a343 100644 --- a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerCancellationTests.scala +++ b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerCancellationTests.scala @@ -11,7 +11,7 @@ import sttp.tapir._ import sttp.tapir.tests._ import java.util.concurrent.atomic.AtomicBoolean -import java.util.concurrent.{Semaphore, TimeUnit, TimeoutException} +import java.util.concurrent.{Semaphore, TimeUnit} import scala.concurrent.duration._ class ServerCancellationTests[F[_], OPTIONS, ROUTE](createServerTest: CreateServerTest[F, Any, OPTIONS, ROUTE])(implicit @@ -32,14 +32,14 @@ class ServerCancellationTests[F[_], OPTIONS, ROUTE](createServerTest: CreateServ }, "Client cancelling request triggers cancellation on the server" ) { (backend, baseUri) => - val resp: IO[_] = basicRequest.get(uri"$baseUri").send(backend).timeout(300.millis) + val resp: IO[_] = basicRequest.get(uri"$baseUri").readTimeout(300.millis).send(backend) resp .map { case result => fail(s"Expected cancellation, but received a result: $result") } .handleErrorWith { - case _: TimeoutException => // expected, this is how we trigged client-side cancellation + case _: SttpClientException.TimeoutException => // expected, this is how we trigged client-side cancellation IO( assert( canceledSemaphore.tryAcquire(30L, TimeUnit.SECONDS), From 4790c8041da0e2a3189382c93c3cb82d824dc1cc Mon Sep 17 00:00:00 2001 From: kciesielski Date: Thu, 26 Oct 2023 10:46:35 +0200 Subject: [PATCH 17/19] Remove dev stuff --- flake.nix | 32 -------------------------------- 1 file changed, 32 deletions(-) delete mode 100644 flake.nix diff --git a/flake.nix b/flake.nix deleted file mode 100644 index 77bdda12f1..0000000000 --- a/flake.nix +++ /dev/null @@ -1,32 +0,0 @@ -{ - description = "JDK 11 shell flake"; - - inputs = { - nixpkgs.url = "github:nixos/nixpkgs/nixos-unstable"; - flake-utils.url = "github:numtide/flake-utils"; - - mach-nix.url = "github:davhau/mach-nix"; - }; - - outputs = { self, nixpkgs, mach-nix, flake-utils, ... }: - let - pythonVersion = "python37"; - in - flake-utils.lib.eachDefaultSystem (system: - let - pkgs = nixpkgs.legacyPackages.${system}; - mach = mach-nix.lib.${system}; - - pythonEnv = mach.mkPython { - python = pythonVersion; - requirements = builtins.readFile ./requirements.txt; - }; - in - { - devShells.default = pkgs.mkShellNoCC { - packages = [ pkgs.jdk11 ]; - - }; - } - ); -} From 995e6e08b94c8c588d9c7c7aa304eeac00d98f70 Mon Sep 17 00:00:00 2001 From: kciesielski Date: Thu, 26 Oct 2023 13:17:48 +0200 Subject: [PATCH 18/19] Add proper attribution to http4s source code --- .../sttp/tapir/server/netty/internal/NettyServerHandler.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala index 153a93c708..7da0226785 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala @@ -40,6 +40,8 @@ class NettyServerHandler[F[_]]( me: MonadError[F] ) extends SimpleChannelInboundHandler[HttpRequest] { + // Cancellation handling with eventLoopContext, lastResponseSent, and pendingResponses has been adapted + // from http4s: https://github.com/http4s/http4s-netty/pull/396/files // By using the Netty event loop assigned to this channel we get two benefits: // 1. We can avoid the necessary hopping around of threads since Netty pipelines will // only pass events up and down from within the event loop to which it is assigned. From bd8b94a8de6b0560ae2a1017c4160c2398d08706 Mon Sep 17 00:00:00 2001 From: kciesielski Date: Tue, 31 Oct 2023 09:55:55 +0100 Subject: [PATCH 19/19] Ensure that request release is called on errors --- .../netty/internal/NettyServerHandler.scala | 20 ++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala index 7da0226785..051e15b061 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala @@ -117,23 +117,29 @@ class NettyServerHandler[F[_]]( } } pendingResponses.enqueue(cancellationSwitch) - lastResponseSent = lastResponseSent.transformWith { _ => + lastResponseSent = lastResponseSent.flatMap { _ => runningFuture.transform { case Success(serverResponse) => pendingResponses.dequeue() try { - handleResponse(ctx, req, serverResponse) - releaseReq() + handleResponse(ctx, req, serverResponse) Success(()) } catch { case NonFatal(ex) => - writeError500(req, ex) + writeError500(req, ex) Failure(ex) + } finally { + val _ = releaseReq() } case Failure(NonFatal(ex)) => - writeError500(req, ex) - Failure(ex) - case Failure(ex) => Failure(ex) + try { + writeError500(req, ex) + Failure(ex) + } + finally { + val _ = releaseReq() + } + case Failure(fatalException) => Failure(fatalException) }(eventLoopContext) }(eventLoopContext) }