Skip to content

Commit

Permalink
Fix post-cancellation cleanup in synchronous netty server
Browse files Browse the repository at this point in the history
  • Loading branch information
adamw committed Nov 12, 2024
1 parent b9b08a3 commit fde3511
Show file tree
Hide file tree
Showing 8 changed files with 109 additions and 14 deletions.
3 changes: 2 additions & 1 deletion .scalafix.conf
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
OrganizeImports.groupedImports = AggressiveMerge
OrganizeImports.targetDialect = Scala3
OrganizeImports.targetDialect = Scala3
OrganizeImports.removeUnused = false
2 changes: 1 addition & 1 deletion project/Versions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ object Versions {
val helidon = "4.0.10"
val sttp = "3.10.1"
val sttpModel = "1.7.11"
val sttpShared = "1.3.22"
val sttpShared = "1.4.0"
val sttpApispec = "0.11.3"
val akkaHttp = "10.2.10"
val akkaStreams = "2.6.20"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ private[metrics] class MetricsEndpointInterceptor[F[_]](
)(implicit monad: MonadError[F], bodyListener: BodyListener[F, B]): F[ServerResponse[B]] = {
if (ignoreEndpoints.contains(ctx.endpoint)) endpointHandler.onDecodeSuccess(ctx)
else {
val responseWithMetrics: F[ServerResponse[B]] = for {
def responseWithMetrics: F[ServerResponse[B]] = for {
_ <- collectRequestMetrics(ctx.endpoint)
response <- endpointHandler.onDecodeSuccess(ctx)
_ <- collectResponseHeadersMetrics(ctx.endpoint, response)
Expand All @@ -64,7 +64,7 @@ private[metrics] class MetricsEndpointInterceptor[F[_]](
)(implicit monad: MonadError[F], bodyListener: BodyListener[F, B]): F[ServerResponse[B]] = {
if (ignoreEndpoints.contains(ctx.endpoint)) endpointHandler.onSecurityFailure(ctx)
else {
val responseWithMetrics: F[ServerResponse[B]] = for {
def responseWithMetrics: F[ServerResponse[B]] = for {
_ <- collectRequestMetrics(ctx.endpoint)
response <- endpointHandler.onSecurityFailure(ctx)
_ <- collectResponseHeadersMetrics(ctx.endpoint, response)
Expand All @@ -83,7 +83,7 @@ private[metrics] class MetricsEndpointInterceptor[F[_]](
)(implicit monad: MonadError[F], bodyListener: BodyListener[F, B]): F[Option[ServerResponse[B]]] = {
if (ignoreEndpoints.contains(ctx.endpoint)) endpointHandler.onDecodeFailure(ctx)
else {
val responseWithMetrics: F[Option[ServerResponse[B]]] = for {
def responseWithMetrics: F[Option[ServerResponse[B]]] = for {
response <- endpointHandler.onDecodeFailure(ctx)
withMetrics <- response match {
case Some(response) =>
Expand Down Expand Up @@ -129,7 +129,7 @@ private[metrics] class MetricsEndpointInterceptor[F[_]](
}
}

private def handleResponseExceptions[T](r: F[T], e: AnyEndpoint)(implicit monad: MonadError[F]): F[T] =
private def handleResponseExceptions[T](r: => F[T], e: AnyEndpoint)(implicit monad: MonadError[F]): F[T] =
r.handleError { case ex: Exception => collectExceptionMetrics(e, ex) }

private def collectExceptionMetrics[T](e: AnyEndpoint, ex: Throwable)(implicit monad: MonadError[F]): F[T] =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import org.scalatest.matchers.should.Matchers._
import cats.effect.unsafe.implicits.global
import sttp.model.StatusCode

class NettyRequestTimeoutTests(eventLoopGroup: EventLoopGroup, backend: SttpBackend[IO, Fs2Streams[IO] with WebSockets])(implicit
class NettyFutureRequestTimeoutTests(eventLoopGroup: EventLoopGroup, backend: SttpBackend[IO, Fs2Streams[IO] with WebSockets])(implicit
ec: ExecutionContext
) {
def tests(): List[Test] = List(
Expand All @@ -39,14 +39,14 @@ class NettyRequestTimeoutTests(eventLoopGroup: EventLoopGroup, backend: SttpBack
onRequest = (_, metric, me) =>
me.eval {
EndpointMetric()
.onEndpointRequest { _ => me.eval(metric.incrementAndGet()) }
.onResponseBody { (_, _) => me.eval(metric.decrementAndGet()) }
.onException { (_, _) => me.eval(metric.decrementAndGet()) }
.onEndpointRequest { _ => me.eval { val _ = metric.incrementAndGet(); } }
.onResponseBody { (_, _) => me.eval { val _ = metric.decrementAndGet(); } }
.onException { (_, _) => me.eval { val _ = metric.decrementAndGet(); } }
}
),
Metric(
metric = totalRequests,
onRequest = (_, metric, me) => me.eval(EndpointMetric().onEndpointRequest { _ => me.eval(metric.incrementAndGet()) })
onRequest = (_, metric, me) => me.eval(EndpointMetric().onEndpointRequest { _ => me.eval { val _ = metric.incrementAndGet(); } })
)
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class NettyFutureServerTest extends TestSuite with EitherValues {
val tests =
new AllServerTests(createServerTest, interpreter, backend, multipart = false).tests() ++
new ServerGracefulShutdownTests(createServerTest, Sleeper.futureSleeper).tests() ++
new NettyRequestTimeoutTests(eventLoopGroup, backend).tests()
new NettyFutureRequestTimeoutTests(eventLoopGroup, backend).tests()

(tests, eventLoopGroup)
}) { case (_, eventLoopGroup) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,12 @@ object NettySyncServerOptions:
doLogWhenReceived = debugLog(_, None),
doLogWhenHandled = debugLog,
doLogAllDecodeFailures = debugLog,
doLogExceptions = (msg: String, ex: Throwable) => log.error(msg, ex),
doLogExceptions = (msg: String, e: Throwable) =>
e match
// if server logic is interrupted (e.g. due to timeout), this isn't an error, but might still be useful for debugging,
// to know how far processing got
case _: InterruptedException => log.debug(msg, e)
case _ => log.error(msg, e),
noLog = ()
)

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
package sttp.tapir.server.netty

import cats.effect.IO
import cats.effect.unsafe.implicits.global
import io.netty.channel.EventLoopGroup
import org.scalatest.matchers.should.Matchers.*
import ox.*
import sttp.capabilities.WebSockets
import sttp.capabilities.fs2.Fs2Streams
import sttp.client3.*
import sttp.model.StatusCode
import sttp.tapir.*
import sttp.tapir.server.interceptor.metrics.MetricsRequestInterceptor
import sttp.tapir.server.metrics.{EndpointMetric, Metric}
import sttp.tapir.server.netty.sync.{NettySyncServer, NettySyncServerOptions}
import sttp.tapir.tests.Test

import java.util.concurrent.atomic.AtomicInteger
import scala.concurrent.Future
import scala.concurrent.duration.DurationInt
import org.slf4j.LoggerFactory

class NettySyncRequestTimeoutTests(eventLoopGroup: EventLoopGroup, backend: SttpBackend[IO, Fs2Streams[IO] with WebSockets]):
val logger = LoggerFactory.getLogger(getClass.getName)

def tests(): List[Test] = List(
Test("properly update metrics when a request times out") {
val e = endpoint.post
.in(stringBody)
.out(stringBody)
.serverLogicSuccess[Identity]: body =>
Thread.sleep(2000)
body

val activeRequests = new AtomicInteger()
val totalRequests = new AtomicInteger()
val customMetrics: List[Metric[Identity, AtomicInteger]] = List(
Metric(
metric = activeRequests,
onRequest = (_, metric, me) =>
me.eval:
EndpointMetric()
.onEndpointRequest: _ =>
val _ = metric.incrementAndGet();
(): Identity[Unit]
.onResponseBody: (_, _) =>
val _ = metric.decrementAndGet();
.onException: (_, _) =>
val _ = metric.decrementAndGet();
),
Metric(
metric = totalRequests,
onRequest = (_, metric, me) =>
me.eval(EndpointMetric().onEndpointRequest: _ =>
val _ = metric.incrementAndGet();
)
)
)

val config =
NettyConfig.default
.eventLoopGroup(eventLoopGroup)
.randomPort
.withDontShutdownEventLoopGroupOnClose
.noGracefulShutdown
.requestTimeout(1.second)
val options = NettySyncServerOptions.customiseInterceptors
.metricsInterceptor(new MetricsRequestInterceptor[Identity](customMetrics, Seq.empty))
.options

Future.successful:
supervised:
val port = useInScope(NettySyncServer(options, config).addEndpoint(e).start())(_.stop()).port
basicRequest
.post(uri"http://localhost:$port")
.body("test")
.send(backend)
.map: response =>
response.body should matchPattern { case Left(_) => }
response.code shouldBe StatusCode.ServiceUnavailable
// unlike in NettyFutureRequestTimeoutTest, here interruption works properly, and the metrics should be updated quickly
Thread.sleep(100)
activeRequests.get() shouldBe 0
totalRequests.get() shouldBe 1
.unsafeRunSync()
}
)
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import scala.concurrent.Future
import scala.concurrent.duration.FiniteDuration
import ox.flow.Flow
import scala.annotation.nowarn
import sttp.tapir.server.netty.NettySyncRequestTimeoutTests

class NettySyncServerTest extends AsyncFunSuite with BeforeAndAfterAll {

Expand All @@ -44,7 +45,8 @@ class NettySyncServerTest extends AsyncFunSuite with BeforeAndAfterAll {
new ServerWebSocketTests(createServerTest, OxStreams, autoPing = true, failingPipe = true, handlePong = true) {
override def functionToPipe[A, B](f: A => B): OxStreams.Pipe[A, B] = _.map(f)
override def emptyPipe[A, B]: OxStreams.Pipe[A, B] = _ => Flow.empty
}.tests()
}.tests() ++
NettySyncRequestTimeoutTests(eventLoopGroup, backend).tests()

tests.foreach { t =>
if (testNameFilter.forall(filter => t.name.contains(filter))) {
Expand Down

0 comments on commit fde3511

Please sign in to comment.