Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix cleanup in Netty handler after a request timeout #4156

Merged
merged 2 commits into from
Nov 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .scalafix.conf
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
OrganizeImports.groupedImports = AggressiveMerge
OrganizeImports.targetDialect = Scala3
OrganizeImports.targetDialect = Scala3
OrganizeImports.removeUnused = false
2 changes: 2 additions & 0 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@ val commonSettings = commonSmlBuildSettings ++ ossPublishSettings ++ Seq(
case _ => Seq("-Xmax-inlines", "64")
}
},
Test / scalacOptions += "-Wconf:msg=unused value of type org.scalatest.Assertion:s",
Test / scalacOptions += "-Wconf:msg=unused value of type org.scalatest.compatible.Assertion:s",
evictionErrorLevel := Level.Info
)

Expand Down
2 changes: 1 addition & 1 deletion project/Versions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ object Versions {
val helidon = "4.0.10"
val sttp = "3.10.1"
val sttpModel = "1.7.11"
val sttpShared = "1.3.22"
val sttpShared = "1.4.0"
val sttpApispec = "0.11.3"
val akkaHttp = "10.2.10"
val akkaStreams = "2.6.20"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ private[metrics] class MetricsEndpointInterceptor[F[_]](
)(implicit monad: MonadError[F], bodyListener: BodyListener[F, B]): F[ServerResponse[B]] = {
if (ignoreEndpoints.contains(ctx.endpoint)) endpointHandler.onDecodeSuccess(ctx)
else {
val responseWithMetrics: F[ServerResponse[B]] = for {
def responseWithMetrics: F[ServerResponse[B]] = for {
_ <- collectRequestMetrics(ctx.endpoint)
response <- endpointHandler.onDecodeSuccess(ctx)
_ <- collectResponseHeadersMetrics(ctx.endpoint, response)
Expand All @@ -64,7 +64,7 @@ private[metrics] class MetricsEndpointInterceptor[F[_]](
)(implicit monad: MonadError[F], bodyListener: BodyListener[F, B]): F[ServerResponse[B]] = {
if (ignoreEndpoints.contains(ctx.endpoint)) endpointHandler.onSecurityFailure(ctx)
else {
val responseWithMetrics: F[ServerResponse[B]] = for {
def responseWithMetrics: F[ServerResponse[B]] = for {
_ <- collectRequestMetrics(ctx.endpoint)
response <- endpointHandler.onSecurityFailure(ctx)
_ <- collectResponseHeadersMetrics(ctx.endpoint, response)
Expand All @@ -83,7 +83,7 @@ private[metrics] class MetricsEndpointInterceptor[F[_]](
)(implicit monad: MonadError[F], bodyListener: BodyListener[F, B]): F[Option[ServerResponse[B]]] = {
if (ignoreEndpoints.contains(ctx.endpoint)) endpointHandler.onDecodeFailure(ctx)
else {
val responseWithMetrics: F[Option[ServerResponse[B]]] = for {
def responseWithMetrics: F[Option[ServerResponse[B]]] = for {
response <- endpointHandler.onDecodeFailure(ctx)
withMetrics <- response match {
case Some(response) =>
Expand Down Expand Up @@ -129,7 +129,7 @@ private[metrics] class MetricsEndpointInterceptor[F[_]](
}
}

private def handleResponseExceptions[T](r: F[T], e: AnyEndpoint)(implicit monad: MonadError[F]): F[T] =
private def handleResponseExceptions[T](r: => F[T], e: AnyEndpoint)(implicit monad: MonadError[F]): F[T] =
r.handleError { case ex: Exception => collectExceptionMetrics(e, ex) }

private def collectExceptionMetrics[T](e: AnyEndpoint, ex: Throwable)(implicit monad: MonadError[F]): F[T] =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,8 @@ case class NettyConfig(

def initPipeline(f: NettyConfig => (ChannelPipeline, ChannelHandler) => Unit): NettyConfig = copy(initPipeline = f)

def withGracefulShutdownTimeout(t: FiniteDuration) = copy(gracefulShutdownTimeout = Some(t))
def noGracefulShutdown = copy(gracefulShutdownTimeout = None)
def withGracefulShutdownTimeout(t: FiniteDuration): NettyConfig = copy(gracefulShutdownTimeout = Some(t))
def noGracefulShutdown: NettyConfig = copy(gracefulShutdownTimeout = None)

def serverHeader(h: String): NettyConfig = copy(serverHeader = Some(h))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ import scala.collection.mutable.{Queue => MutableQueue}
import scala.concurrent.{ExecutionContext, Future}
import scala.util.control.NonFatal
import scala.util.{Failure, Success}
import java.util.concurrent.TimeoutException
import org.reactivestreams.Subscriber
import org.reactivestreams.Subscription

/** @param unsafeRunAsync
* Function which dispatches given effect to run asynchronously, returning its result as a Future, and function of type `() =>
Expand Down Expand Up @@ -109,11 +112,13 @@ class NettyServerHandler[F[_]](
evt match {
case e: IdleStateEvent =>
if (e.state() == IdleState.WRITER_IDLE) {
logger.error(s"Closing connection due to exceeded response timeout of ${config.requestTimeout}")
logger.error(
s"Closing connection due to exceeded response timeout of ${config.requestTimeout.map(_.toString).getOrElse("(not set)")}"
)
writeError503ThenClose(ctx)
}
if (e.state() == IdleState.ALL_IDLE) {
logger.debug(s"Closing connection due to exceeded idle timeout of ${config.idleTimeout}")
logger.debug(s"Closing connection due to exceeded idle timeout of ${config.idleTimeout.map(_.toString).getOrElse("(not set)")}")
val _ = ctx.close()
}
case other =>
Expand Down Expand Up @@ -147,30 +152,42 @@ class NettyServerHandler[F[_]](
pendingResponses.enqueue(cancellationSwitch)
lastResponseSent = lastResponseSent.flatMap { _ =>
runningFuture
.andThen { case _ =>
requestTimeoutHandler.foreach(ctx.pipeline().remove)
}(eventLoopContext)
.transform {
case Success(serverResponse) =>
pendingResponses.dequeue()
try {
handleResponse(ctx, req, serverResponse)
Success(())
} catch {
case NonFatal(ex) =>
writeError500(req, ex)
Failure(ex)
} finally {
val _ = releaseReq()
}
case Failure(NonFatal(ex)) =>
try {
writeError500(req, ex)
Failure(ex)
} finally {
val _ = releaseReq()
.transform { result =>
try {
// #4131: the channel might be closed if the request timed out
// both timeout & response-ready events (i.e., completing this future) are handled on the event loop's executor,
// so they won't be handled concurrently
if (ctx.channel().isOpen()) {
requestTimeoutHandler.foreach(ctx.pipeline().remove)
result match {
case Success(serverResponse) =>
pendingResponses.dequeue()
try {
handleResponse(ctx, req, serverResponse)
Success(())
} catch {
case NonFatal(ex) =>
writeError500(req, ex)
Failure(ex)
}
case Failure(NonFatal(ex)) =>
writeError500(req, ex)
Failure(ex)
case Failure(fatalException) => Failure(fatalException)
}
} else {
// pendingResponses is already dequeued because the channel is closed
result match {
case Success(serverResponse) =>
val e = new TimeoutException("Request timed out")
handleResponseAfterTimeout(ctx, serverResponse, e)
Failure(e)
case Failure(e) => Failure(e)
}
}
case Failure(fatalException) => Failure(fatalException)
} finally {
val _ = releaseReq()
}
}(eventLoopContext)
}(eventLoopContext)
}
Expand Down Expand Up @@ -270,6 +287,39 @@ class NettyServerHandler[F[_]](
}
)

/** Discards a server response that was produced after the channel had already been closed (the caller invokes this
  * when `ctx.channel().isOpen()` is false and the server logic nevertheless completed successfully).
  *
  * Nothing is written to the channel. Instead, for every kind of response body the associated channel promise is
  * failed with `timeoutException`, and body sources that can be closed or cancelled are closed/cancelled first.
  *
  * @param ctx
  *   the handler context, passed through to `serverResponse.handle`
  * @param serverResponse
  *   the late response whose body must be disposed of
  * @param timeoutException
  *   the exception with which each body's channel promise is failed
  */
private def handleResponseAfterTimeout(
    ctx: ChannelHandlerContext,
    serverResponse: ServerResponse[NettyResponse],
    timeoutException: Exception
): Unit =
  serverResponse.handle(
    ctx = ctx,
    // in-memory body: only the promise needs to be failed
    byteBufHandler = (promise, _) => {
      val _ = promise.setFailure(timeoutException)
    },
    // chunked bodies: close the source first, then fail the promise
    chunkedStreamHandler = (promise, stream) => {
      stream.close()
      val _ = promise.setFailure(timeoutException)
    },
    chunkedFileHandler = (promise, file) => {
      file.close()
      val _ = promise.setFailure(timeoutException)
    },
    // streaming body: subscribe only in order to cancel the subscription straight away
    reactiveStreamHandler = (promise, bodyPublisher) => {
      bodyPublisher.subscribe(new Subscriber[HttpContent] {
        override def onNext(content: HttpContent): Unit = ()
        override def onError(cause: Throwable): Unit = ()
        override def onComplete(): Unit = ()
        override def onSubscribe(subscription: Subscription): Unit = {
          subscription.cancel()
          val _ = promise.setFailure(timeoutException)
        }
      })
    },
    wsHandler = wsContent => {
      val _ = wsContent.channelPromise.setFailure(timeoutException)
    },
    // no body, hence nothing to clean up
    noBodyHandler = () => ()
  )

private def initWsPipeline(
ctx: ChannelHandlerContext,
r: ReactiveWebSocketProcessorNettyResponseContent,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
package sttp.tapir.server.netty

import sttp.tapir._
import sttp.tapir.tests.Test
import scala.concurrent.Future
import java.util.concurrent.atomic.AtomicInteger
import scala.concurrent.duration.DurationInt
import sttp.tapir.server.interceptor.metrics.MetricsRequestInterceptor
import sttp.tapir.server.metrics.Metric
import sttp.tapir.server.metrics.EndpointMetric
import io.netty.channel.EventLoopGroup
import cats.effect.IO
import cats.effect.kernel.Resource
import scala.concurrent.ExecutionContext
import sttp.client3._
import sttp.capabilities.fs2.Fs2Streams
import sttp.capabilities.WebSockets
import org.scalatest.matchers.should.Matchers._
import cats.effect.unsafe.implicits.global
import sttp.model.StatusCode

/** Request-timeout test for the `Future`-based Netty server.
  *
  * Starts a server with a 1-second request timeout, calls an endpoint whose logic blocks for 2 seconds, and checks
  * both the response seen by the client and the values of two custom metrics once the endpoint's logic has finished.
  *
  * @param eventLoopGroup
  *   event loop group installed in the server's `NettyConfig`
  * @param backend
  *   sttp client backend used to send the test request
  */
class NettyFutureRequestTimeoutTests(eventLoopGroup: EventLoopGroup, backend: SttpBackend[IO, Fs2Streams[IO] with WebSockets])(implicit
    ec: ExecutionContext
) {
  def tests(): List[Test] = List(
    Test("properly update metrics when a request times out") {
      // echoes the request body, but only after blocking for longer than the request timeout configured below
      val slowEndpoint = endpoint.post
        .in(stringBody)
        .out(stringBody)
        .serverLogicSuccess[Future] { body =>
          Thread.sleep(2000)
          Future.successful(body)
        }

      val activeRequests = new AtomicInteger()
      val totalRequests = new AtomicInteger()

      // incremented when an endpoint request starts, decremented on a response body or on an exception
      val activeRequestsMetric: Metric[Future, AtomicInteger] = Metric(
        metric = activeRequests,
        onRequest = (_, counter, monad) =>
          monad.eval {
            EndpointMetric()
              .onEndpointRequest { _ => monad.eval { val _ = counter.incrementAndGet(); } }
              .onResponseBody { (_, _) => monad.eval { val _ = counter.decrementAndGet(); } }
              .onException { (_, _) => monad.eval { val _ = counter.decrementAndGet(); } }
          }
      )
      // only ever incremented, once per endpoint request
      val totalRequestsMetric: Metric[Future, AtomicInteger] = Metric(
        metric = totalRequests,
        onRequest = (_, counter, monad) =>
          monad.eval(EndpointMetric().onEndpointRequest { _ => monad.eval { val _ = counter.incrementAndGet(); } })
      )
      val customMetrics: List[Metric[Future, AtomicInteger]] = List(activeRequestsMetric, totalRequestsMetric)

      val config =
        NettyConfig.default
          .eventLoopGroup(eventLoopGroup)
          .randomPort
          .withDontShutdownEventLoopGroupOnClose
          .noGracefulShutdown
          .requestTimeout(1.second)
      val options = NettyFutureServerOptions.customiseInterceptors
        .metricsInterceptor(new MetricsRequestInterceptor[Future](customMetrics, Seq.empty))
        .options

      // the server is started when the resource is acquired and stopped when it is released
      val startServer = IO.fromFuture(IO.delay(NettyFutureServer(options, config).addEndpoints(List(slowEndpoint)).start()))
      val serverPort = Resource
        .make(startServer)(binding => IO.fromFuture(IO.delay(binding.stop())))
        .map(_.port)

      serverPort
        .use { port =>
          val request = basicRequest.post(uri"http://localhost:$port").body("test")
          request.send(backend).map { response =>
            response.body should matchPattern { case Left(_) => }
            response.code shouldBe StatusCode.ServiceUnavailable
            // the metrics will only be updated when the endpoint's logic completes, which is 1 second after receiving the timeout response
            Thread.sleep(1100)
            activeRequests.get() shouldBe 0
            totalRequests.get() shouldBe 1
          }
        }
        .unsafeToFuture()
    }
  )
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ class NettyFutureServerTest extends TestSuite with EitherValues {

val tests =
new AllServerTests(createServerTest, interpreter, backend, multipart = false).tests() ++
new ServerGracefulShutdownTests(createServerTest, Sleeper.futureSleeper).tests()
new ServerGracefulShutdownTests(createServerTest, Sleeper.futureSleeper).tests() ++
new NettyFutureRequestTimeoutTests(eventLoopGroup, backend).tests()

(tests, eventLoopGroup)
}) { case (_, eventLoopGroup) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,12 @@ object NettySyncServerOptions:
doLogWhenReceived = debugLog(_, None),
doLogWhenHandled = debugLog,
doLogAllDecodeFailures = debugLog,
doLogExceptions = (msg: String, ex: Throwable) => log.error(msg, ex),
doLogExceptions = (msg: String, e: Throwable) =>
e match
// if server logic is interrupted (e.g. due to timeout), this isn't an error, but might still be useful for debugging,
// to know how far processing got
case _: InterruptedException => log.debug(msg, e)
case _ => log.error(msg, e),
noLog = ()
)

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
package sttp.tapir.server.netty

import cats.effect.IO
import cats.effect.unsafe.implicits.global
import io.netty.channel.EventLoopGroup
import org.scalatest.matchers.should.Matchers.*
import ox.*
import sttp.capabilities.WebSockets
import sttp.capabilities.fs2.Fs2Streams
import sttp.client3.*
import sttp.model.StatusCode
import sttp.tapir.*
import sttp.tapir.server.interceptor.metrics.MetricsRequestInterceptor
import sttp.tapir.server.metrics.{EndpointMetric, Metric}
import sttp.tapir.server.netty.sync.{NettySyncServer, NettySyncServerOptions}
import sttp.tapir.tests.Test

import java.util.concurrent.atomic.AtomicInteger
import scala.concurrent.Future
import scala.concurrent.duration.DurationInt
import org.slf4j.LoggerFactory

/** Request-timeout test for the synchronous (`Identity`-based) Netty server.
  *
  * Starts a `NettySyncServer` with a 1-second request timeout, calls an endpoint whose logic sleeps for 2 seconds, and
  * checks the response seen by the client as well as the values of two custom metrics shortly afterwards.
  *
  * @param eventLoopGroup
  *   event loop group installed in the server's `NettyConfig`
  * @param backend
  *   sttp client backend used to send the test request
  */
class NettySyncRequestTimeoutTests(eventLoopGroup: EventLoopGroup, backend: SttpBackend[IO, Fs2Streams[IO] with WebSockets]):
// NOTE(review): `logger` is not referenced anywhere else in this class as shown here
val logger = LoggerFactory.getLogger(getClass.getName)

def tests(): List[Test] = List(
Test("properly update metrics when a request times out") {
// endpoint which echoes the request body, but only after sleeping for longer than the request timeout configured below
val e = endpoint.post
.in(stringBody)
.out(stringBody)
.serverLogicSuccess[Identity]: body =>
Thread.sleep(2000)
body

// activeRequests: incremented on an endpoint request, decremented on a response body or on an exception
// totalRequests: only incremented, once per endpoint request
val activeRequests = new AtomicInteger()
val totalRequests = new AtomicInteger()
val customMetrics: List[Metric[Identity, AtomicInteger]] = List(
Metric(
metric = activeRequests,
onRequest = (_, metric, me) =>
me.eval:
EndpointMetric()
.onEndpointRequest: _ =>
val _ = metric.incrementAndGet();
(): Identity[Unit]
.onResponseBody: (_, _) =>
val _ = metric.decrementAndGet();
.onException: (_, _) =>
val _ = metric.decrementAndGet();
),
Metric(
metric = totalRequests,
onRequest = (_, metric, me) =>
me.eval(EndpointMetric().onEndpointRequest: _ =>
val _ = metric.incrementAndGet();
)
)
)

// server configuration: random port, no graceful shutdown, requests time out after 1 second
val config =
NettyConfig.default
.eventLoopGroup(eventLoopGroup)
.randomPort
.withDontShutdownEventLoopGroupOnClose
.noGracefulShutdown
.requestTimeout(1.second)
val options = NettySyncServerOptions.customiseInterceptors
.metricsInterceptor(new MetricsRequestInterceptor[Identity](customMetrics, Seq.empty))
.options

// the whole scenario is run synchronously (see `unsafeRunSync()` below); its result is wrapped in an already-completed Future
Future.successful:
supervised:
// the server is started within the supervised scope, with `stop()` registered as its release action
val port = useInScope(NettySyncServer(options, config).addEndpoint(e).start())(_.stop()).port
basicRequest
.post(uri"http://localhost:$port")
.body("test")
.send(backend)
.map: response =>
response.body should matchPattern { case Left(_) => }
response.code shouldBe StatusCode.ServiceUnavailable
// unlike in NettyFutureRequestTimeoutTests, here interruption works properly, and the metrics should be updated quickly
Thread.sleep(100)
activeRequests.get() shouldBe 0
totalRequests.get() shouldBe 1
.unsafeRunSync()
}
)
Loading
Loading