From 1c6d3060a069b312a4d2b6317d378ac5bc8e3b9e Mon Sep 17 00:00:00 2001 From: Pierre Ricadat Date: Wed, 15 Dec 2021 09:51:40 +0900 Subject: [PATCH] Make RequestInterceptor more like old ContextWrapper (#1208) --- .../main/scala/example/akkahttp/AuthExampleApp.scala | 6 ++---- .../src/main/scala/example/play/AuthExampleApp.scala | 4 ++-- .../caliban/interop/tapir/RequestInterceptor.scala | 8 ++++---- .../scala/caliban/interop/tapir/TapirAdapter.scala | 10 +++++----- 4 files changed, 13 insertions(+), 15 deletions(-) diff --git a/examples/src/main/scala/example/akkahttp/AuthExampleApp.scala b/examples/src/main/scala/example/akkahttp/AuthExampleApp.scala index 69ef7a52d..287fffa31 100644 --- a/examples/src/main/scala/example/akkahttp/AuthExampleApp.scala +++ b/examples/src/main/scala/example/akkahttp/AuthExampleApp.scala @@ -25,13 +25,11 @@ object AuthExampleApp extends App { type Auth = Has[FiberRef[Option[AuthToken]]] object AuthInterceptor extends RequestInterceptor[Auth] { - override def apply[R <: Auth]( - request: ServerRequest - ): ZIO[R, StatusCode, Unit] = + override def apply[R <: Auth, A](request: ServerRequest)(effect: ZIO[R, StatusCode, A]): ZIO[R, StatusCode, A] = request.headers.collectFirst { case header if header.is("token") => header.value } match { - case Some(token) => ZIO.accessM[Auth](_.get.set(Some(AuthToken(token)))) + case Some(token) => ZIO.accessM[Auth](_.get.set(Some(AuthToken(token)))) *> effect case _ => ZIO.fail(StatusCode.Forbidden) } } diff --git a/examples/src/main/scala/example/play/AuthExampleApp.scala b/examples/src/main/scala/example/play/AuthExampleApp.scala index f0a122a3f..8ffacd3c6 100644 --- a/examples/src/main/scala/example/play/AuthExampleApp.scala +++ b/examples/src/main/scala/example/play/AuthExampleApp.scala @@ -30,9 +30,9 @@ object AuthExampleApp extends App { implicit val executionContext: ExecutionContextExecutor = system.dispatcher object AuthWrapper extends RequestInterceptor[Auth] { - def apply[R1 <: Auth](request: ServerRequest): ZIO[R1, StatusCode, Unit] = + override def apply[R <: Auth, A](request: ServerRequest)(effect: ZIO[R, StatusCode, A]): ZIO[R, StatusCode, A] = request.header("token") match { - case Some(token) => ZIO.accessM[Auth](_.get.set(Some(AuthToken(token)))) + case Some(token) => ZIO.accessM[Auth](_.get.set(Some(AuthToken(token)))) *> effect case None => ZIO.fail(StatusCode.Forbidden) } } diff --git a/interop/tapir/src/main/scala/caliban/interop/tapir/RequestInterceptor.scala b/interop/tapir/src/main/scala/caliban/interop/tapir/RequestInterceptor.scala index 00f501d6e..08e1629cb 100644 --- a/interop/tapir/src/main/scala/caliban/interop/tapir/RequestInterceptor.scala +++ b/interop/tapir/src/main/scala/caliban/interop/tapir/RequestInterceptor.scala @@ -9,16 +9,16 @@ import zio.ZIO * query execution or injecting context into ZIO environment. */ trait RequestInterceptor[-R] { self => - def apply[R1 <: R](request: ServerRequest): ZIO[R1, StatusCode, Unit] + def apply[R1 <: R, A](request: ServerRequest)(e: ZIO[R1, StatusCode, A]): ZIO[R1, StatusCode, A] def |+|[R1 <: R](that: RequestInterceptor[R1]): RequestInterceptor[R1] = new RequestInterceptor[R1] { - override def apply[R2 <: R1](request: ServerRequest): ZIO[R2, StatusCode, Unit] = - that.apply[R2](request) *> self.apply[R2](request) + override def apply[R2 <: R1, A](request: ServerRequest)(e: ZIO[R2, StatusCode, A]): ZIO[R2, StatusCode, A] = + that.apply[R2, A](request)(self.apply[R2, A](request)(e)) } } object RequestInterceptor { def empty: RequestInterceptor[Any] = new RequestInterceptor[Any] { - override def apply[R](request: ServerRequest): ZIO[R, StatusCode, Unit] = ZIO.unit + override def apply[R, A](request: ServerRequest)(e: ZIO[R, StatusCode, A]): ZIO[R, StatusCode, A] = e } } diff --git a/interop/tapir/src/main/scala/caliban/interop/tapir/TapirAdapter.scala b/interop/tapir/src/main/scala/caliban/interop/tapir/TapirAdapter.scala index 50b8ddc14..c8637fbc7 100644 --- a/interop/tapir/src/main/scala/caliban/interop/tapir/TapirAdapter.scala +++ b/interop/tapir/src/main/scala/caliban/interop/tapir/TapirAdapter.scala @@ -106,14 +106,15 @@ object TapirAdapter { def logic(request: (GraphQLRequest, ServerRequest)): RIO[R, Either[StatusCode, GraphQLResponse[E]]] = { val (graphQLRequest, serverRequest) = request - (requestInterceptor(serverRequest) *> + requestInterceptor(serverRequest)( interpreter .executeRequest( graphQLRequest, skipValidation = skipValidation, enableIntrospection = enableIntrospection, queryExecution - )).either + ) + ).either } makeHttpEndpoints.map(_.serverLogic(logic)) @@ -147,7 +148,6 @@ object TapirAdapter { val io = for { - _ <- requestInterceptor(serverRequest) rawOperations <- ZIO.fromOption(partsMap.get("operations")) orElseFail StatusCode.BadRequest request <- requestCodec.rawDecode(new String(rawOperations.body, "utf-8")) match { case _: DecodeResult.Failure => ZIO.fail(StatusCode.BadRequest) @@ -191,7 +191,7 @@ object TapirAdapter { .provideSomeLayer[R with Random](uploadQuery.fileHandle.toLayerMany) } yield response - io.either + requestInterceptor(serverRequest)(io).either } makeHttpUploadEndpoint.serverLogic(logic) @@ -288,7 +288,7 @@ object TapirAdapter { } yield pipe makeWebSocketEndpoint.serverLogic[RIO[R, *]](serverRequest => - requestInterceptor(serverRequest).foldM(statusCode => ZIO.left(statusCode), _ => io) + requestInterceptor(serverRequest)(io).catchAll(ZIO.left(_)) ) }