From 110d1ecf8d0fa8d4e82667d83d20da3c42c0c400 Mon Sep 17 00:00:00 2001 From: Stephen Bly Date: Wed, 26 Oct 2022 16:42:46 -0400 Subject: [PATCH] Add ability to only count certain exceptions as failures --- .../circuit/CircuitBreaker.scala | 50 +++++++++++++------ .../circuit/CircuitBreakerTests.scala | 24 +++++++-- 2 files changed, 56 insertions(+), 18 deletions(-) diff --git a/core/src/main/scala/io/chrisdavenport/circuit/CircuitBreaker.scala b/core/src/main/scala/io/chrisdavenport/circuit/CircuitBreaker.scala index 743bcac..b40b7ff 100644 --- a/core/src/main/scala/io/chrisdavenport/circuit/CircuitBreaker.scala +++ b/core/src/main/scala/io/chrisdavenport/circuit/CircuitBreaker.scala @@ -247,14 +247,17 @@ object CircuitBreaker { * default implementations. * @param maxResetTimeout is the maximum timeout the circuit breaker * is allowed to use when applying the `backoff` result. + * @param exceptionFilter a predicate that returns true for exceptions which should trigger the circuitbreaker, + * and false for those which should not (ie be treated the same as success) */ def of[F[_]]( maxFailures: Int, resetTimeout: FiniteDuration, backoff: FiniteDuration => FiniteDuration = Backoff.exponential, - maxResetTimeout: Duration = 1.minute + maxResetTimeout: Duration = 1.minute, + exceptionFilter: Throwable => Boolean = Function.const(true) )(implicit F: Temporal[F]): F[CircuitBreaker[F]] = { - of(maxFailures, resetTimeout, backoff, maxResetTimeout, F.unit, F. unit, F.unit, F.unit) + of(maxFailures, resetTimeout, backoff, maxResetTimeout, exceptionFilter, F.unit, F. unit, F.unit, F.unit) } /** @@ -275,14 +278,17 @@ object CircuitBreaker { * default implementations. * @param maxResetTimeout is the maximum timeout the circuit breaker * is allowed to use when applying the `backoff` + * @param exceptionFilter a predicate that returns true for exceptions which should trigger the circuitbreaker, + * and false for those which should not (ie be treated the same as success) */ def in[F[_], G[_]]( maxFailures: Int, resetTimeout: FiniteDuration, backoff: FiniteDuration => FiniteDuration = Backoff.exponential, - maxResetTimeout: Duration = 1.minute + maxResetTimeout: Duration = 1.minute, + exceptionFilter: Throwable => Boolean = Function.const(true) )(implicit F: Sync[F], G: Async[G]): F[CircuitBreaker[G]] = { - in[F, G](maxFailures, resetTimeout, backoff, maxResetTimeout, G.unit, G.unit, G.unit, G.unit) + in[F, G](maxFailures, resetTimeout, backoff, maxResetTimeout, exceptionFilter, G.unit, G.unit, G.unit, G.unit) } /** @@ -303,6 +309,8 @@ object CircuitBreaker { * default implementations. * @param maxResetTimeout is the maximum timeout the circuit breaker * is allowed to use when applying the `backoff` + * @param exceptionFilter a predicate that returns true for exceptions which should trigger the circuitbreaker, + * and false for those which should not (ie be treated the same as success) * * @param onRejected is for signaling rejected tasks * @param onClosed is for signaling a transition to `Closed` @@ -314,6 +322,7 @@ object CircuitBreaker { resetTimeout: FiniteDuration, backoff: FiniteDuration => FiniteDuration, maxResetTimeout: Duration, + exceptionFilter: Throwable => Boolean, onRejected: F[Unit], onClosed: F[Unit], onHalfOpen: F[Unit], @@ -327,6 +336,7 @@ object CircuitBreaker { resetTimeout, backoff, maxResetTimeout, + exceptionFilter, onRejected, onClosed, onHalfOpen, @@ -355,6 +365,8 @@ object CircuitBreaker { * default implementations. * @param maxResetTimeout is the maximum timeout the circuit breaker * is allowed to use when applying the `backoff` + * @param exceptionFilter a predicate that returns true for exceptions which should trigger the circuitbreaker, + * and false for those which should not (ie be treated the same as success) * * @param onRejected is for signaling rejected tasks * @param onClosed is for signaling a transition to `Closed` @@ -366,6 +378,7 @@ object CircuitBreaker { resetTimeout: FiniteDuration, backoff: FiniteDuration => FiniteDuration, maxResetTimeout: Duration, + exceptionFilter: Throwable => Boolean, onRejected: G[Unit], onClosed: G[Unit], onHalfOpen: G[Unit], @@ -378,6 +391,7 @@ object CircuitBreaker { resetTimeout, backoff, maxResetTimeout, + exceptionFilter, onRejected, onClosed, onHalfOpen, @@ -397,6 +411,7 @@ object CircuitBreaker { resetTimeout: FiniteDuration, backoff: FiniteDuration => FiniteDuration, maxResetTimeout: Duration, + exceptionFilter: Throwable => Boolean, onRejected: G[Unit], onClosed: G[Unit], onHalfOpen: G[Unit], @@ -407,6 +422,7 @@ object CircuitBreaker { resetTimeout, backoff, maxResetTimeout, + exceptionFilter, onRejected, onClosed, onHalfOpen, @@ -510,6 +526,7 @@ object CircuitBreaker { resetTimeout: FiniteDuration, backoff: FiniteDuration => FiniteDuration, maxResetTimeout: Duration, + exceptionFilter: Throwable => Boolean, onRejected: F[Unit], onClosed: F[Unit], onHalfOpen: F[Unit], @@ -523,9 +540,7 @@ object CircuitBreaker { require(maxResetTimeout > Duration.Zero, "maxResetTimeout > 0") - def state: F[CircuitBreaker.State] = - ref.get - + def state: F[CircuitBreaker.State] = ref.get def doOnRejected(callback: F[Unit]): CircuitBreaker[F] = { val onRejected = this.onRejected.flatMap(_ => callback) @@ -535,6 +550,7 @@ object CircuitBreaker { resetTimeout = resetTimeout, backoff = backoff, maxResetTimeout = maxResetTimeout, + exceptionFilter = exceptionFilter, onRejected = onRejected, onClosed = onClosed, onHalfOpen = onHalfOpen, @@ -549,6 +565,7 @@ object CircuitBreaker { resetTimeout = resetTimeout, backoff = backoff, maxResetTimeout = maxResetTimeout, + exceptionFilter = exceptionFilter, onRejected = onRejected, onClosed = onClosed, onHalfOpen = onHalfOpen, @@ -563,6 +580,7 @@ object CircuitBreaker { resetTimeout = resetTimeout, backoff = backoff, maxResetTimeout = maxResetTimeout, + exceptionFilter = exceptionFilter, onRejected = onRejected, onClosed = onClosed, onHalfOpen = onHalfOpen, @@ -578,6 +596,7 @@ object CircuitBreaker { resetTimeout = resetTimeout, backoff = backoff, maxResetTimeout = maxResetTimeout, + exceptionFilter = exceptionFilter, onRejected = onRejected, onClosed = onClosed, onHalfOpen = onHalfOpen, @@ -593,14 +612,15 @@ object CircuitBreaker { case HalfOpen => (ClosedZero, onClosed.attempt.void) case Open(_,_) => (ClosedZero, onClosed.attempt.void) }.flatten - case Outcome.Errored(_) => + case Outcome.Errored(e) => Temporal[F].monotonic.map(_.toMillis).flatMap { now => - ref.modify { - case Closed(failures) => - val count = failures + 1 - if (count >= maxFailures) (Open(now, resetTimeout), onOpen.attempt.void) - else (Closed(count), Applicative[F].unit) + case closed @ Closed(failures) => + if (exceptionFilter(e)) { + val count = failures + 1 + if (count >= maxFailures) (Open(now, resetTimeout), onOpen.attempt.void) + else (Closed(count), Applicative[F].unit) + } else (closed, Applicative[F].unit) case open: Open => (open, Applicative[F].unit) case HalfOpen => (HalfOpen, Applicative[F].unit) }.flatten @@ -631,7 +651,9 @@ object CircuitBreaker { def resetOnSuccess(poll: Poll[F]): F[A] = { poll(fa).guaranteeCase { case Outcome.Succeeded(_) => ref.set(ClosedZero) >> onClosed.attempt.void - case Outcome.Errored(_) => ref.set(nextBackoff(open, now)) >> onOpen.attempt.void + case Outcome.Errored(e) => + if (exceptionFilter(e)) ref.set(nextBackoff(open, now)) >> onOpen.attempt.void + else ref.set(ClosedZero) >> onClosed.attempt.void case Outcome.Canceled() => ref.modify{ case HalfOpen => (open, onOpen.attempt.void) case closed: Closed => (closed, F.unit) diff --git a/core/src/test/scala/io/chrisdavenport/circuit/CircuitBreakerTests.scala b/core/src/test/scala/io/chrisdavenport/circuit/CircuitBreakerTests.scala index 44d1120..4057dab 100644 --- a/core/src/test/scala/io/chrisdavenport/circuit/CircuitBreakerTests.scala +++ b/core/src/test/scala/io/chrisdavenport/circuit/CircuitBreakerTests.scala @@ -28,7 +28,6 @@ import cats.syntax.all._ import scala.concurrent.duration._ import cats.effect._ // import cats.effect.syntax._ -import cats.effect.unsafe._ // import catalysts.Platform import munit.CatsEffectSuite @@ -36,9 +35,6 @@ import munit.CatsEffectSuite class CircuitBreakerTests extends CatsEffectSuite { private val Tries = 10000 //if (Platform.isJvm) 10000 else 5000 - - implicit val runtime: IORuntime = cats.effect.unsafe.IORuntime.global - private def mkBreaker() = CircuitBreaker.in[SyncIO, IO]( maxFailures = 5, resetTimeout = 1.minute @@ -335,4 +331,24 @@ class CircuitBreakerTests extends CatsEffectSuite { test } + + test("should only count allowed exceptions") { + case class MyException(foo: String) extends Throwable + + for { + circuitBreaker <- CircuitBreaker.of[IO](maxFailures = 1, resetTimeout = 10.seconds, exceptionFilter = !_.isInstanceOf[MyException]) + action = circuitBreaker.protect(IO.raiseError(MyException("Boom!"))).attempt + _ <- action >> action >> action >> action + _ <- circuitBreaker.state.map { + case _: CircuitBreaker.Closed => assert(true) + case _ => assert(false) + } + badAction = circuitBreaker.protect(IO.raiseError(new RuntimeException("Boom!"))).attempt + _ <- badAction >> badAction >> badAction >> badAction + _ <- circuitBreaker.state.map { + case _: CircuitBreaker.Open => assert(true) + case _ => assert(false) + } + } yield () + } }