diff --git a/doc/server/http4s.md b/doc/server/http4s.md index 894f20b1e8..abd598e3b8 100644 --- a/doc/server/http4s.md +++ b/doc/server/http4s.md @@ -41,31 +41,6 @@ It's completely feasible that some part of the input is read using a http4s wrap with the tapir endpoint descriptions. Moreover, "edge-case endpoints", which require some special logic not expressible using tapir, can be always implemented directly using http4s. - -If you need a `ContextRoutes` (or its type alias `AuthedRoutes`) with a `SomeCtx` context intead of a `HttpRoutes`: - -```scala mdoc:compile-only -import sttp.tapir._ -import sttp.tapir.server.http4s.Http4sServerInterpreter -import sttp.tapir.server.http4s.InputWithContext -import cats.effect.IO -import org.http4s.ContextRoutes - -case class SomeCtx(actionAllowed: Boolean) // the context expected from http4s - -def countCharacters(in: InputWithContext[String, SomeCtx]): IO[Either[Unit, Int]] = - IO.pure( - if(in.context.actionAllowed) Right[Unit, Int](in.input.length) else Left[Unit, Int](()) - ) - -val countCharactersEndpoint: PublicEndpoint[String, Unit, Int, Any] = - endpoint.in(stringBody).out(plainBody[Int]) -val countCharactersRoutes: ContextRoutes[SomeCtx, IO] = - Http4sServerInterpreter[IO]() - .withContext[SomeCtx]() // you may give it a name (default to "defaultContext") - .toContextRoutes(countCharactersEndpoint)(_.serverLogic(countCharacters _)) -``` - ## Streaming The http4s interpreter accepts streaming bodies of type `Stream[F, Byte]`, as described by the `Fs2Streams` @@ -139,6 +114,37 @@ val routes = Http4sServerInterpreter[IO]().toRoutes(sseEndpoint.serverLogicSucce )) ``` +## Accessing http4s context + +If you'd like to access context provided by an http4s middleware, e.g. with authentication data, this can be done +with a dedicated context-extracting input, `.contextIn`. Endpoints using such input need then to be interpreted to +`org.http4s.ContextRoutes` (also known by its type alias `AuthedRoutes`) using the `.toContextRoutes` method. + +For example: + +```scala mdoc:compile-only +import sttp.tapir._ +import sttp.tapir.server.http4s._ +import cats.effect.IO +import org.http4s.ContextRoutes + +case class SomeCtx(actionAllowed: Boolean) // the context expected from http4s middleware + +def countCharacters(in: (String, SomeCtx)): IO[Either[Unit, Int]] = + IO.pure( + if(in._2.actionAllowed) Right[Unit, Int](in._1.length) else Left[Unit, Int](()) + ) + +// the .contextIn extension method is imported from the sttp.tapir.server.http4s package +// the Context[SomeCtx] capability requirement requires interpretation to be done using .toContextRoutes +val countCharactersEndpoint: PublicEndpoint[(String, SomeCtx), Unit, Int, Context[SomeCtx]] = + endpoint.in(stringBody).contextIn[SomeCtx]().out(plainBody[Int]) + +val countCharactersRoutes: ContextRoutes[SomeCtx, IO] = + Http4sServerInterpreter[IO]() + .toContextRoutes(countCharactersEndpoint.serverLogic(countCharacters _)) +``` + ## Configuration The interpreter can be configured by providing an `Http4sServerOptions` value, see diff --git a/server/http4s-server/src/main/scala/sttp/tapir/server/http4s/Http4sServerInterpreter.scala b/server/http4s-server/src/main/scala/sttp/tapir/server/http4s/Http4sServerInterpreter.scala index fff3f834fe..6001c7554c 100644 --- a/server/http4s-server/src/main/scala/sttp/tapir/server/http4s/Http4sServerInterpreter.scala +++ b/server/http4s-server/src/main/scala/sttp/tapir/server/http4s/Http4sServerInterpreter.scala @@ -2,62 +2,31 @@ package sttp.tapir.server.http4s import cats.data.{Kleisli, OptionT} import cats.effect.Async -import cats.effect.std.Queue import cats.implicits._ -import fs2.{Pipe, Stream} import org.http4s._ import org.http4s.headers.`Content-Length` import org.http4s.server.websocket.WebSocketBuilder2 -import org.http4s.websocket.WebSocketFrame import org.typelevel.ci.CIString import sttp.capabilities.WebSockets import sttp.capabilities.fs2.Fs2Streams import sttp.tapir._ import sttp.tapir.integ.cats.effect.CatsMonadError -import sttp.tapir.model.ServerRequest import sttp.tapir.server.ServerEndpoint import sttp.tapir.server.interceptor.RequestResult import sttp.tapir.server.interceptor.reject.RejectInterceptor import sttp.tapir.server.interpreter.{BodyListener, FilterServerEndpoints, ServerInterpreter} import sttp.tapir.server.model.ServerResponse +import scala.reflect.ClassTag + class Http4sInvalidWebSocketUse(val message: String) extends Exception -final case class InputWithContext[In, Ctx](input: In, context: Ctx) +/** A capability that is used by endpoints, when they need to access the http4s-provided context. Such a requirement can be added using the + * [[RichHttp4sEndpoint.contextIn]] method. + */ +trait Context[T] trait Http4sServerInterpreter[F[_]] { - - // builder to create a ContextRoutes[Ctx, F] instead of a HttpRoutes[F] - // allowing to delegate this context retieval to http4s (eg. for authentication) - // the context is put in the request attributes, then retrieved and passed to the endpoint - final class ContextRoutesBuilder[Ctx](name: String) { - - private val attrKey = new AttributeKey[Ctx](name) - - def toContextRoutes[S, I, E, O, R]( - endpoint: Endpoint[S, I, E, O, R], - f: Endpoint[S, InputWithContext[I, Ctx], E, O, R] => List[ServerEndpoint[Fs2Streams[F], F]] - )(implicit dummy: DummyImplicit): ContextRoutes[Ctx, F] = { - - val endpointWithContext = - endpoint - .in(extractFromRequest { (req: ServerRequest) => - req - .attribute(attrKey) - // should never happen since http4s had to build a ContextRequest with Ctx for ContextRoutes - .getOrElse(throw new RuntimeException(s"context ${name} not found in the request")) - }) - .mapIn(tuple => (InputWithContext.apply[I, Ctx](_, _)).tupled(tuple))(tuple => (tuple.input, tuple.context)) - - innerContextRoutes[Ctx](attrKey, f(endpointWithContext), None) - } - - def toContextRoutes[S, I, E, O, R](endpoint: Endpoint[S, I, E, O, R])( - f: Endpoint[S, InputWithContext[I, Ctx], E, O, R] => ServerEndpoint[Fs2Streams[F], F] - ): ContextRoutes[Ctx, F] = - toContextRoutes(endpoint, (e: Endpoint[S, InputWithContext[I, Ctx], E, O, R]) => List(f(e))) - } - implicit def fa: Async[F] def http4sServerOptions: Http4sServerOptions[F] = Http4sServerOptions.default[F] @@ -77,62 +46,62 @@ trait Http4sServerInterpreter[F[_]] { serverEndpoints: List[ServerEndpoint[Fs2Streams[F] with WebSockets, F]] ): WebSocketBuilder2[F] => HttpRoutes[F] = wsb => toRoutes(serverEndpoints, Some(wsb)) - def withContext[Ctx](name: String = "defaultContext"): ContextRoutesBuilder[Ctx] = - new ContextRoutesBuilder[Ctx](name) + def toContextRoutes[T: ClassTag](se: ServerEndpoint[Fs2Streams[F] with Context[T], F]): ContextRoutes[T, F] = + toContextRoutes(contextAttributeKey[T], List(se), None) - private def toRoutes( - serverEndpoints: List[ServerEndpoint[Fs2Streams[F] with WebSockets, F]], - webSocketBuilder: Option[WebSocketBuilder2[F]] - ): HttpRoutes[F] = { + def toContextRoutes[T: ClassTag](ses: List[ServerEndpoint[Fs2Streams[F] with Context[T], F]]): ContextRoutes[T, F] = + toContextRoutes(contextAttributeKey[T], ses, None) + + private def createInterpreter[T]( + serverEndpoints: List[ServerEndpoint[Fs2Streams[F] with WebSockets with Context[T], F]] + ): ServerInterpreter[Fs2Streams[F] with WebSockets with Context[T], F, Http4sResponseBody[F], Fs2Streams[F]] = { implicit val monad: CatsMonadError[F] = new CatsMonadError[F] implicit val bodyListener: BodyListener[F, Http4sResponseBody[F]] = new Http4sBodyListener[F] - val interpreter = new ServerInterpreter[Fs2Streams[F] with WebSockets, F, Http4sResponseBody[F], Fs2Streams[F]]( + new ServerInterpreter( FilterServerEndpoints(serverEndpoints), new Http4sRequestBody[F](http4sServerOptions), new Http4sToResponseBody[F](http4sServerOptions), RejectInterceptor.disableWhenSingleEndpoint(http4sServerOptions.interceptors, serverEndpoints), http4sServerOptions.deleteFile ) + } + + private def toResponse[T]( + interpreter: ServerInterpreter[Fs2Streams[F] with WebSockets with Context[T], F, Http4sResponseBody[F], Fs2Streams[F]], + serverRequest: Http4sServerRequest[F], + webSocketBuilder: Option[WebSocketBuilder2[F]] + ): OptionT[F, Response[F]] = + OptionT(interpreter(serverRequest).flatMap { + case _: RequestResult.Failure => none.pure[F] + case RequestResult.Response(response) => serverResponseToHttp4s(response, webSocketBuilder).map(_.some) + }) + + private def toRoutes( + serverEndpoints: List[ServerEndpoint[Fs2Streams[F] with WebSockets, F]], + webSocketBuilder: Option[WebSocketBuilder2[F]] + ): HttpRoutes[F] = { + val interpreter = createInterpreter(serverEndpoints) Kleisli { (req: Request[F]) => val serverRequest = Http4sServerRequest(req) - - OptionT(interpreter(serverRequest).flatMap { - case _: RequestResult.Failure => none.pure[F] - case RequestResult.Response(response) => serverResponseToHttp4s(response, webSocketBuilder).map(_.some) - }) + toResponse(interpreter, serverRequest, webSocketBuilder) } } - private def innerContextRoutes[T]( - attributeKey: AttributeKey[T], - serverEndpoints: List[ServerEndpoint[Fs2Streams[F] with WebSockets, F]], + private def toContextRoutes[T]( + contextAttributeKey: AttributeKey[T], + serverEndpoints: List[ServerEndpoint[Fs2Streams[F] with WebSockets with Context[T], F]], webSocketBuilder: Option[WebSocketBuilder2[F]] ): ContextRoutes[T, F] = { - implicit val monad: CatsMonadError[F] = new CatsMonadError[F] - implicit val bodyListener: BodyListener[F, Http4sResponseBody[F]] = new Http4sBodyListener[F] - - val interpreter = new ServerInterpreter[Fs2Streams[F] with WebSockets, F, Http4sResponseBody[F], Fs2Streams[F]]( - FilterServerEndpoints(serverEndpoints), - new Http4sRequestBody[F](http4sServerOptions), - new Http4sToResponseBody[F](http4sServerOptions), - RejectInterceptor.disableWhenSingleEndpoint(http4sServerOptions.interceptors, serverEndpoints), - http4sServerOptions.deleteFile - ) + val interpreter = createInterpreter(serverEndpoints) Kleisli { (contextRequest: ContextRequest[F, T]) => - val serverRequest = - Http4sServerRequest( - contextRequest.req, - AttributeMap.Empty - .put(attributeKey, contextRequest.context) - ) - - OptionT(interpreter(serverRequest).flatMap { - case _: RequestResult.Failure => none.pure[F] - case RequestResult.Response(response) => serverResponseToHttp4s(response, webSocketBuilder).map(_.some) - }) + val serverRequest = Http4sServerRequest( + contextRequest.req, + AttributeMap.Empty.put(contextAttributeKey, contextRequest.context) + ) + toResponse(interpreter, serverRequest, webSocketBuilder) } } diff --git a/server/http4s-server/src/main/scala/sttp/tapir/server/http4s/package.scala b/server/http4s-server/src/main/scala/sttp/tapir/server/http4s/package.scala index a396a51e3a..3820b7ca51 100644 --- a/server/http4s-server/src/main/scala/sttp/tapir/server/http4s/package.scala +++ b/server/http4s-server/src/main/scala/sttp/tapir/server/http4s/package.scala @@ -5,9 +5,12 @@ import org.http4s.EntityBody import org.http4s.websocket.WebSocketFrame import sttp.capabilities.fs2.Fs2Streams import sttp.model.sse.ServerSentEvent -import sttp.tapir.{CodecFormat, StreamBodyIO, streamTextBody} +import sttp.tapir.model.ServerRequest +import sttp.tapir.typelevel.ParamConcat +import sttp.tapir.{AttributeKey, CodecFormat, Endpoint, StreamBodyIO, extractFromRequest, streamTextBody} import java.nio.charset.Charset +import scala.reflect.ClassTag package object http4s { // either a web socket, or a stream with optional length (if known) @@ -18,4 +21,31 @@ package object http4s { streamTextBody(fs2Streams)(CodecFormat.TextEventStream(), Some(Charset.forName("UTF-8"))) .map(Http4sServerSentEvents.parseBytesToSSE[F])(Http4sServerSentEvents.serialiseSSEToBytes[F]) } + + private[http4s] def contextAttributeKey[T: ClassTag]: AttributeKey[T] = new AttributeKey(implicitly[ClassTag[T]].runtimeClass.getName) + + implicit class RichHttp4sEndpoint[A, I, E, O, R](e: Endpoint[A, I, E, O, R]) { + + /** Access the context provided by an http4s middleware, such as authentication data. + * + * Interpreting endpoints which access the http4s context requires the usage of the [[Http4sServerInterpreter.toContextRoutes]] + * method. This then yields a [[org.http4s.ContextRoutes]] instance, which needs to be correctly mounted in the http4s router. + * + * Note that the correct syntax for adding the context input includes `()` after the method invocation, to properly infer types and + * capture implicit parameters, e.g. `myEndpoint.contextIn[Auth]()`. + */ + def contextIn[T]: AddContextInput[T] = new AddContextInput[T] + + class AddContextInput[T] { + def apply[IT]()(implicit concat: ParamConcat.Aux[I, T, IT], ct: ClassTag[T]): Endpoint[A, IT, E, O, R with Context[T]] = { + val attribute = contextAttributeKey[T] + e.in(extractFromRequest[T] { (req: ServerRequest) => + req + .attribute(attribute) + // should never happen since http4s had to build a ContextRequest with Ctx for ContextRoutes + .getOrElse(throw new RuntimeException(s"context ${attribute.typeName} not found in the request")) + }) + } + } + } } diff --git a/server/http4s-server/src/test/scala/sttp/tapir/server/http4s/Http4sServerTest.scala b/server/http4s-server/src/test/scala/sttp/tapir/server/http4s/Http4sServerTest.scala index c441079e15..bf80ef1cad 100644 --- a/server/http4s-server/src/test/scala/sttp/tapir/server/http4s/Http4sServerTest.scala +++ b/server/http4s-server/src/test/scala/sttp/tapir/server/http4s/Http4sServerTest.scala @@ -57,15 +57,13 @@ class Http4sServerTest[R >: Fs2Streams[IO] with WebSockets] extends TestSuite wi Test("should work with a router and context routes in a context") { val expectedContext: String = "Hello World!" // the context we expect http4s to provide to the endpoint - def serverFn(in: InputWithContext[Unit, String]) = IO.pure(Right[Unit, String](in.context)) - - val e = endpoint.get.in("test" / "router").out(stringBody) + val e: Endpoint[Unit, String, Unit, String, Context[String]] = + endpoint.get.in("test" / "router").contextIn[String]().out(stringBody) val routesWithContext: ContextRoutes[String, IO] = Http4sServerInterpreter[IO]() - .withContext[String]() // server logic is to return the context as is - .toContextRoutes(e)(_.serverLogic[IO](serverFn _)) + .toContextRoutes(e.serverLogicSuccess(ctx => IO.pure(ctx))) // middleware to add the context to each request (so here string constant) val middleware: ContextMiddleware[IO, String] =