Skip to content

Commit

Permalink
Alternative implementation of contextual endpoints
Browse files Browse the repository at this point in the history
  • Loading branch information
adamw committed Aug 17, 2023
1 parent f0bff06 commit dc8dd09
Show file tree
Hide file tree
Showing 4 changed files with 107 additions and 104 deletions.
56 changes: 31 additions & 25 deletions doc/server/http4s.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,31 +41,6 @@ It's completely feasible that some part of the input is read using a http4s wrap
with the tapir endpoint descriptions. Moreover, "edge-case endpoints", which require some special logic not expressible
using tapir, can be always implemented directly using http4s.


If you need a `ContextRoutes` (or its type alias `AuthedRoutes`) with a `SomeCtx` context intead of a `HttpRoutes`:

```scala mdoc:compile-only
import sttp.tapir._
import sttp.tapir.server.http4s.Http4sServerInterpreter
import sttp.tapir.server.http4s.InputWithContext
import cats.effect.IO
import org.http4s.ContextRoutes

case class SomeCtx(actionAllowed: Boolean) // the context expected from http4s

def countCharacters(in: InputWithContext[String, SomeCtx]): IO[Either[Unit, Int]] =
IO.pure(
if(in.context.actionAllowed) Right[Unit, Int](in.input.length) else Left[Unit, Int](())
)

val countCharactersEndpoint: PublicEndpoint[String, Unit, Int, Any] =
endpoint.in(stringBody).out(plainBody[Int])
val countCharactersRoutes: ContextRoutes[SomeCtx, IO] =
Http4sServerInterpreter[IO]()
.withContext[SomeCtx]() // you may give it a name (default to "defaultContext")
.toContextRoutes(countCharactersEndpoint)(_.serverLogic(countCharacters _))
```

## Streaming

The http4s interpreter accepts streaming bodies of type `Stream[F, Byte]`, as described by the `Fs2Streams`
Expand Down Expand Up @@ -139,6 +114,37 @@ val routes = Http4sServerInterpreter[IO]().toRoutes(sseEndpoint.serverLogicSucce
))
```

## Accessing http4s context

If you'd like to access context provided by an http4s middleware, e.g. with authentication data, this can be done
with a dedicated context-extracting input, `.contextIn`. Endpoints using such input need then to be interpreted to
`org.http4s.ContextRoutes` (also known by its type alias `AuthedRoutes`) using the `.toContextRoutes` method.

For example:

```scala mdoc:compile-only
import sttp.tapir._
import sttp.tapir.server.http4s._
import cats.effect.IO
import org.http4s.ContextRoutes

case class SomeCtx(actionAllowed: Boolean) // the context expected from http4s middleware

def countCharacters(in: (String, SomeCtx)): IO[Either[Unit, Int]] =
IO.pure(
if(in._2.actionAllowed) Right[Unit, Int](in._1.length) else Left[Unit, Int](())
)

// the .contextIn extension method is imported from the sttp.tapir.server.http4s package
// the Context[SomeCtx] capability requirement requires interpretation to be done using .toContextRoutes
val countCharactersEndpoint: PublicEndpoint[(String, SomeCtx), Unit, Int, Context[SomeCtx]] =
endpoint.in(stringBody).contextIn[SomeCtx]().out(plainBody[Int])

val countCharactersRoutes: ContextRoutes[SomeCtx, IO] =
Http4sServerInterpreter[IO]()
.toContextRoutes(countCharactersEndpoint.serverLogic(countCharacters _))
```

## Configuration

The interpreter can be configured by providing an `Http4sServerOptions` value, see
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,62 +2,31 @@ package sttp.tapir.server.http4s

import cats.data.{Kleisli, OptionT}
import cats.effect.Async
import cats.effect.std.Queue
import cats.implicits._
import fs2.{Pipe, Stream}
import org.http4s._
import org.http4s.headers.`Content-Length`
import org.http4s.server.websocket.WebSocketBuilder2
import org.http4s.websocket.WebSocketFrame
import org.typelevel.ci.CIString
import sttp.capabilities.WebSockets
import sttp.capabilities.fs2.Fs2Streams
import sttp.tapir._
import sttp.tapir.integ.cats.effect.CatsMonadError
import sttp.tapir.model.ServerRequest
import sttp.tapir.server.ServerEndpoint
import sttp.tapir.server.interceptor.RequestResult
import sttp.tapir.server.interceptor.reject.RejectInterceptor
import sttp.tapir.server.interpreter.{BodyListener, FilterServerEndpoints, ServerInterpreter}
import sttp.tapir.server.model.ServerResponse

import scala.reflect.ClassTag

class Http4sInvalidWebSocketUse(val message: String) extends Exception

final case class InputWithContext[In, Ctx](input: In, context: Ctx)
/** A capability that is used by endpoints, when they need to access the http4s-provided context. Such a requirement can be added using the
* [[RichHttp4sEndpoint.contextIn]] method.
*/
trait Context[T]

trait Http4sServerInterpreter[F[_]] {

// builder to create a ContextRoutes[Ctx, F] instead of a HttpRoutes[F]
// allowing to delegate this context retieval to http4s (eg. for authentication)
// the context is put in the request attributes, then retrieved and passed to the endpoint
final class ContextRoutesBuilder[Ctx](name: String) {

private val attrKey = new AttributeKey[Ctx](name)

def toContextRoutes[S, I, E, O, R](
endpoint: Endpoint[S, I, E, O, R],
f: Endpoint[S, InputWithContext[I, Ctx], E, O, R] => List[ServerEndpoint[Fs2Streams[F], F]]
)(implicit dummy: DummyImplicit): ContextRoutes[Ctx, F] = {

val endpointWithContext =
endpoint
.in(extractFromRequest { (req: ServerRequest) =>
req
.attribute(attrKey)
// should never happen since http4s had to build a ContextRequest with Ctx for ContextRoutes
.getOrElse(throw new RuntimeException(s"context ${name} not found in the request"))
})
.mapIn(tuple => (InputWithContext.apply[I, Ctx](_, _)).tupled(tuple))(tuple => (tuple.input, tuple.context))

innerContextRoutes[Ctx](attrKey, f(endpointWithContext), None)
}

def toContextRoutes[S, I, E, O, R](endpoint: Endpoint[S, I, E, O, R])(
f: Endpoint[S, InputWithContext[I, Ctx], E, O, R] => ServerEndpoint[Fs2Streams[F], F]
): ContextRoutes[Ctx, F] =
toContextRoutes(endpoint, (e: Endpoint[S, InputWithContext[I, Ctx], E, O, R]) => List(f(e)))
}

implicit def fa: Async[F]

def http4sServerOptions: Http4sServerOptions[F] = Http4sServerOptions.default[F]
Expand All @@ -77,62 +46,62 @@ trait Http4sServerInterpreter[F[_]] {
serverEndpoints: List[ServerEndpoint[Fs2Streams[F] with WebSockets, F]]
): WebSocketBuilder2[F] => HttpRoutes[F] = wsb => toRoutes(serverEndpoints, Some(wsb))

def withContext[Ctx](name: String = "defaultContext"): ContextRoutesBuilder[Ctx] =
new ContextRoutesBuilder[Ctx](name)
def toContextRoutes[T: ClassTag](se: ServerEndpoint[Fs2Streams[F] with Context[T], F]): ContextRoutes[T, F] =
toContextRoutes(contextAttributeKey[T], List(se), None)

private def toRoutes(
serverEndpoints: List[ServerEndpoint[Fs2Streams[F] with WebSockets, F]],
webSocketBuilder: Option[WebSocketBuilder2[F]]
): HttpRoutes[F] = {
def toContextRoutes[T: ClassTag](ses: List[ServerEndpoint[Fs2Streams[F] with Context[T], F]]): ContextRoutes[T, F] =
toContextRoutes(contextAttributeKey[T], ses, None)

private def createInterpreter[T](
serverEndpoints: List[ServerEndpoint[Fs2Streams[F] with WebSockets with Context[T], F]]
): ServerInterpreter[Fs2Streams[F] with WebSockets with Context[T], F, Http4sResponseBody[F], Fs2Streams[F]] = {
implicit val monad: CatsMonadError[F] = new CatsMonadError[F]
implicit val bodyListener: BodyListener[F, Http4sResponseBody[F]] = new Http4sBodyListener[F]

val interpreter = new ServerInterpreter[Fs2Streams[F] with WebSockets, F, Http4sResponseBody[F], Fs2Streams[F]](
new ServerInterpreter(
FilterServerEndpoints(serverEndpoints),
new Http4sRequestBody[F](http4sServerOptions),
new Http4sToResponseBody[F](http4sServerOptions),
RejectInterceptor.disableWhenSingleEndpoint(http4sServerOptions.interceptors, serverEndpoints),
http4sServerOptions.deleteFile
)
}

private def toResponse[T](
interpreter: ServerInterpreter[Fs2Streams[F] with WebSockets with Context[T], F, Http4sResponseBody[F], Fs2Streams[F]],
serverRequest: Http4sServerRequest[F],
webSocketBuilder: Option[WebSocketBuilder2[F]]
): OptionT[F, Response[F]] =
OptionT(interpreter(serverRequest).flatMap {
case _: RequestResult.Failure => none.pure[F]
case RequestResult.Response(response) => serverResponseToHttp4s(response, webSocketBuilder).map(_.some)
})

private def toRoutes(
serverEndpoints: List[ServerEndpoint[Fs2Streams[F] with WebSockets, F]],
webSocketBuilder: Option[WebSocketBuilder2[F]]
): HttpRoutes[F] = {
val interpreter = createInterpreter(serverEndpoints)

Kleisli { (req: Request[F]) =>
val serverRequest = Http4sServerRequest(req)

OptionT(interpreter(serverRequest).flatMap {
case _: RequestResult.Failure => none.pure[F]
case RequestResult.Response(response) => serverResponseToHttp4s(response, webSocketBuilder).map(_.some)
})
toResponse(interpreter, serverRequest, webSocketBuilder)
}
}

private def innerContextRoutes[T](
attributeKey: AttributeKey[T],
serverEndpoints: List[ServerEndpoint[Fs2Streams[F] with WebSockets, F]],
private def toContextRoutes[T](
contextAttributeKey: AttributeKey[T],
serverEndpoints: List[ServerEndpoint[Fs2Streams[F] with WebSockets with Context[T], F]],
webSocketBuilder: Option[WebSocketBuilder2[F]]
): ContextRoutes[T, F] = {
implicit val monad: CatsMonadError[F] = new CatsMonadError[F]
implicit val bodyListener: BodyListener[F, Http4sResponseBody[F]] = new Http4sBodyListener[F]

val interpreter = new ServerInterpreter[Fs2Streams[F] with WebSockets, F, Http4sResponseBody[F], Fs2Streams[F]](
FilterServerEndpoints(serverEndpoints),
new Http4sRequestBody[F](http4sServerOptions),
new Http4sToResponseBody[F](http4sServerOptions),
RejectInterceptor.disableWhenSingleEndpoint(http4sServerOptions.interceptors, serverEndpoints),
http4sServerOptions.deleteFile
)
val interpreter = createInterpreter(serverEndpoints)

Kleisli { (contextRequest: ContextRequest[F, T]) =>
val serverRequest =
Http4sServerRequest(
contextRequest.req,
AttributeMap.Empty
.put(attributeKey, contextRequest.context)
)

OptionT(interpreter(serverRequest).flatMap {
case _: RequestResult.Failure => none.pure[F]
case RequestResult.Response(response) => serverResponseToHttp4s(response, webSocketBuilder).map(_.some)
})
val serverRequest = Http4sServerRequest(
contextRequest.req,
AttributeMap.Empty.put(contextAttributeKey, contextRequest.context)
)
toResponse(interpreter, serverRequest, webSocketBuilder)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,12 @@ import org.http4s.EntityBody
import org.http4s.websocket.WebSocketFrame
import sttp.capabilities.fs2.Fs2Streams
import sttp.model.sse.ServerSentEvent
import sttp.tapir.{CodecFormat, StreamBodyIO, streamTextBody}
import sttp.tapir.model.ServerRequest
import sttp.tapir.typelevel.ParamConcat
import sttp.tapir.{AttributeKey, CodecFormat, Endpoint, StreamBodyIO, extractFromRequest, streamTextBody}

import java.nio.charset.Charset
import scala.reflect.ClassTag

package object http4s {
// either a web socket, or a stream with optional length (if known)
Expand All @@ -18,4 +21,31 @@ package object http4s {
streamTextBody(fs2Streams)(CodecFormat.TextEventStream(), Some(Charset.forName("UTF-8")))
.map(Http4sServerSentEvents.parseBytesToSSE[F])(Http4sServerSentEvents.serialiseSSEToBytes[F])
}

private[http4s] def contextAttributeKey[T: ClassTag]: AttributeKey[T] = new AttributeKey(implicitly[ClassTag[T]].runtimeClass.getName)

implicit class RichHttp4sEndpoint[A, I, E, O, R](e: Endpoint[A, I, E, O, R]) {

/** Access the context provided by an http4s middleware, such as authentication data.
*
* Interpreting endpoints which access the http4s context requires the usage of the [[Http4sServerInterpreter.toContextRoutes]]
* method. This then yields a [[org.http4s.ContextRoutes]] instance, which needs to be correctly mounted in the http4s router.
*
* Note that the correct syntax for adding the context input includes `()` after the method invocation, to properly infer types and
* capture implicit parameters, e.g. `myEndpoint.contextIn[Auth]()`.
*/
def contextIn[T]: AddContextInput[T] = new AddContextInput[T]

class AddContextInput[T] {
def apply[IT]()(implicit concat: ParamConcat.Aux[I, T, IT], ct: ClassTag[T]): Endpoint[A, IT, E, O, R with Context[T]] = {
val attribute = contextAttributeKey[T]
e.in(extractFromRequest[T] { (req: ServerRequest) =>
req
.attribute(attribute)
// should never happen since http4s had to build a ContextRequest with Ctx for ContextRoutes
.getOrElse(throw new RuntimeException(s"context ${attribute.typeName} not found in the request"))
})
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -57,15 +57,13 @@ class Http4sServerTest[R >: Fs2Streams[IO] with WebSockets] extends TestSuite wi
Test("should work with a router and context routes in a context") {
val expectedContext: String = "Hello World!" // the context we expect http4s to provide to the endpoint

def serverFn(in: InputWithContext[Unit, String]) = IO.pure(Right[Unit, String](in.context))

val e = endpoint.get.in("test" / "router").out(stringBody)
val e: Endpoint[Unit, String, Unit, String, Context[String]] =
endpoint.get.in("test" / "router").contextIn[String]().out(stringBody)

val routesWithContext: ContextRoutes[String, IO] =
Http4sServerInterpreter[IO]()
.withContext[String]()
// server logic is to return the context as is
.toContextRoutes(e)(_.serverLogic[IO](serverFn _))
.toContextRoutes(e.serverLogicSuccess(ctx => IO.pure(ctx)))

// middleware to add the context to each request (so here string constant)
val middleware: ContextMiddleware[IO, String] =
Expand Down

0 comments on commit dc8dd09

Please sign in to comment.