Skip to content

Commit

Permalink
Merge pull request #3101 from mprevel/http4s_context_routes
Browse files Browse the repository at this point in the history
Http4s context routes support in server interpreter
  • Loading branch information
adamw authored Aug 18, 2023
2 parents b3e7a79 + dc8dd09 commit f149c99
Show file tree
Hide file tree
Showing 4 changed files with 143 additions and 14 deletions.
31 changes: 31 additions & 0 deletions doc/server/http4s.md
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,37 @@ val routes = Http4sServerInterpreter[IO]().toRoutes(sseEndpoint.serverLogicSucce
))
```

## Accessing http4s context

If you'd like to access context provided by an http4s middleware, e.g. with authentication data, this can be done
with a dedicated context-extracting input, `.contextIn`. Endpoints using such input need then to be interpreted to
`org.http4s.ContextRoutes` (also known by its type alias `AuthedRoutes`) using the `.toContextRoutes` method.

For example:

```scala mdoc:compile-only
import sttp.tapir._
import sttp.tapir.server.http4s._
import cats.effect.IO
import org.http4s.ContextRoutes

case class SomeCtx(actionAllowed: Boolean) // the context expected from http4s middleware

def countCharacters(in: (String, SomeCtx)): IO[Either[Unit, Int]] =
IO.pure(
if(in._2.actionAllowed) Right[Unit, Int](in._1.length) else Left[Unit, Int](())
)

// the .contextIn extension method is imported from the sttp.tapir.server.http4s package
// the Context[SomeCtx] capability requirement requires interpretation to be done using .toContextRoutes
val countCharactersEndpoint: PublicEndpoint[(String, SomeCtx), Unit, Int, Context[SomeCtx]] =
endpoint.in(stringBody).contextIn[SomeCtx]().out(plainBody[Int])

val countCharactersRoutes: ContextRoutes[SomeCtx, IO] =
Http4sServerInterpreter[IO]()
.toContextRoutes(countCharactersEndpoint.serverLogic(countCharacters _))
```

## Configuration

The interpreter can be configured by providing an `Http4sServerOptions` value, see
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,27 +2,31 @@ package sttp.tapir.server.http4s

import cats.data.{Kleisli, OptionT}
import cats.effect.Async
import cats.effect.std.Queue
import cats.implicits._
import fs2.{Pipe, Stream}
import org.http4s._
import org.http4s.headers.`Content-Length`
import org.http4s.server.websocket.WebSocketBuilder2
import org.http4s.websocket.WebSocketFrame
import org.typelevel.ci.CIString
import sttp.capabilities.WebSockets
import sttp.capabilities.fs2.Fs2Streams
import sttp.tapir._
import sttp.tapir.integ.cats.effect.CatsMonadError
import sttp.tapir.server.ServerEndpoint
import sttp.tapir.server.interceptor.RequestResult
import sttp.tapir.server.interceptor.reject.RejectInterceptor
import sttp.tapir.server.interpreter.{BodyListener, FilterServerEndpoints, ServerInterpreter}
import sttp.tapir.server.model.ServerResponse

import scala.reflect.ClassTag

class Http4sInvalidWebSocketUse(val message: String) extends Exception

trait Http4sServerInterpreter[F[_]] {
/** A capability that is used by endpoints, when they need to access the http4s-provided context. Such a requirement can be added using the
* [[RichHttp4sEndpoint.contextIn]] method.
*/
trait Context[T]

trait Http4sServerInterpreter[F[_]] {
implicit def fa: Async[F]

def http4sServerOptions: Http4sServerOptions[F] = Http4sServerOptions.default[F]
Expand All @@ -42,28 +46,62 @@ trait Http4sServerInterpreter[F[_]] {
serverEndpoints: List[ServerEndpoint[Fs2Streams[F] with WebSockets, F]]
): WebSocketBuilder2[F] => HttpRoutes[F] = wsb => toRoutes(serverEndpoints, Some(wsb))

private def toRoutes(
serverEndpoints: List[ServerEndpoint[Fs2Streams[F] with WebSockets, F]],
webSocketBuilder: Option[WebSocketBuilder2[F]]
): HttpRoutes[F] = {
def toContextRoutes[T: ClassTag](se: ServerEndpoint[Fs2Streams[F] with Context[T], F]): ContextRoutes[T, F] =
toContextRoutes(contextAttributeKey[T], List(se), None)

def toContextRoutes[T: ClassTag](ses: List[ServerEndpoint[Fs2Streams[F] with Context[T], F]]): ContextRoutes[T, F] =
toContextRoutes(contextAttributeKey[T], ses, None)

private def createInterpreter[T](
serverEndpoints: List[ServerEndpoint[Fs2Streams[F] with WebSockets with Context[T], F]]
): ServerInterpreter[Fs2Streams[F] with WebSockets with Context[T], F, Http4sResponseBody[F], Fs2Streams[F]] = {
implicit val monad: CatsMonadError[F] = new CatsMonadError[F]
implicit val bodyListener: BodyListener[F, Http4sResponseBody[F]] = new Http4sBodyListener[F]

val interpreter = new ServerInterpreter[Fs2Streams[F] with WebSockets, F, Http4sResponseBody[F], Fs2Streams[F]](
new ServerInterpreter(
FilterServerEndpoints(serverEndpoints),
new Http4sRequestBody[F](http4sServerOptions),
new Http4sToResponseBody[F](http4sServerOptions),
RejectInterceptor.disableWhenSingleEndpoint(http4sServerOptions.interceptors, serverEndpoints),
http4sServerOptions.deleteFile
)
}

private def toResponse[T](
interpreter: ServerInterpreter[Fs2Streams[F] with WebSockets with Context[T], F, Http4sResponseBody[F], Fs2Streams[F]],
serverRequest: Http4sServerRequest[F],
webSocketBuilder: Option[WebSocketBuilder2[F]]
): OptionT[F, Response[F]] =
OptionT(interpreter(serverRequest).flatMap {
case _: RequestResult.Failure => none.pure[F]
case RequestResult.Response(response) => serverResponseToHttp4s(response, webSocketBuilder).map(_.some)
})

private def toRoutes(
serverEndpoints: List[ServerEndpoint[Fs2Streams[F] with WebSockets, F]],
webSocketBuilder: Option[WebSocketBuilder2[F]]
): HttpRoutes[F] = {
val interpreter = createInterpreter(serverEndpoints)

Kleisli { (req: Request[F]) =>
val serverRequest = Http4sServerRequest(req)
toResponse(interpreter, serverRequest, webSocketBuilder)
}
}

OptionT(interpreter(serverRequest).flatMap {
case _: RequestResult.Failure => none.pure[F]
case RequestResult.Response(response) => serverResponseToHttp4s(response, webSocketBuilder).map(_.some)
})
private def toContextRoutes[T](
contextAttributeKey: AttributeKey[T],
serverEndpoints: List[ServerEndpoint[Fs2Streams[F] with WebSockets with Context[T], F]],
webSocketBuilder: Option[WebSocketBuilder2[F]]
): ContextRoutes[T, F] = {
val interpreter = createInterpreter(serverEndpoints)

Kleisli { (contextRequest: ContextRequest[F, T]) =>
val serverRequest = Http4sServerRequest(
contextRequest.req,
AttributeMap.Empty.put(contextAttributeKey, contextRequest.context)
)
toResponse(interpreter, serverRequest, webSocketBuilder)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,12 @@ import org.http4s.EntityBody
import org.http4s.websocket.WebSocketFrame
import sttp.capabilities.fs2.Fs2Streams
import sttp.model.sse.ServerSentEvent
import sttp.tapir.{CodecFormat, StreamBodyIO, streamTextBody}
import sttp.tapir.model.ServerRequest
import sttp.tapir.typelevel.ParamConcat
import sttp.tapir.{AttributeKey, CodecFormat, Endpoint, StreamBodyIO, extractFromRequest, streamTextBody}

import java.nio.charset.Charset
import scala.reflect.ClassTag

package object http4s {
// either a web socket, or a stream with optional length (if known)
Expand All @@ -18,4 +21,31 @@ package object http4s {
streamTextBody(fs2Streams)(CodecFormat.TextEventStream(), Some(Charset.forName("UTF-8")))
.map(Http4sServerSentEvents.parseBytesToSSE[F])(Http4sServerSentEvents.serialiseSSEToBytes[F])
}

private[http4s] def contextAttributeKey[T: ClassTag]: AttributeKey[T] = new AttributeKey(implicitly[ClassTag[T]].runtimeClass.getName)

implicit class RichHttp4sEndpoint[A, I, E, O, R](e: Endpoint[A, I, E, O, R]) {

/** Access the context provided by an http4s middleware, such as authentication data.
*
* Interpreting endpoints which access the http4s context requires the usage of the [[Http4sServerInterpreter.toContextRoutes]]
* method. This then yields a [[org.http4s.ContextRoutes]] instance, which needs to be correctly mounted in the http4s router.
*
* Note that the correct syntax for adding the context input includes `()` after the method invocation, to properly infer types and
* capture implicit parameters, e.g. `myEndpoint.contextIn[Auth]()`.
*/
def contextIn[T]: AddContextInput[T] = new AddContextInput[T]

class AddContextInput[T] {
def apply[IT]()(implicit concat: ParamConcat.Aux[I, T, IT], ct: ClassTag[T]): Endpoint[A, IT, E, O, R with Context[T]] = {
val attribute = contextAttributeKey[T]
e.in(extractFromRequest[T] { (req: ServerRequest) =>
req
.attribute(attribute)
// should never happen since http4s had to build a ContextRequest with Ctx for ContextRoutes
.getOrElse(throw new RuntimeException(s"context ${attribute.typeName} not found in the request"))
})
}
}
}
}
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
package sttp.tapir.server.http4s

import cats.data._
import cats.effect._
import cats.effect.unsafe.implicits.global
import cats.syntax.all._
import fs2.Pipe
import org.http4s.blaze.server.BlazeServerBuilder
import org.http4s.server.Router
import org.http4s.server.ContextMiddleware
import org.http4s.ContextRoutes
import org.http4s.HttpRoutes
import org.scalatest.OptionValues
import org.scalatest.matchers.should.Matchers._
import sttp.capabilities.WebSockets
Expand Down Expand Up @@ -50,6 +54,32 @@ class Http4sServerTest[R >: Fs2Streams[IO] with WebSockets] extends TestSuite wi
}
.unsafeRunSync()
},
Test("should work with a router and context routes in a context") {
val expectedContext: String = "Hello World!" // the context we expect http4s to provide to the endpoint

val e: Endpoint[Unit, String, Unit, String, Context[String]] =
endpoint.get.in("test" / "router").contextIn[String]().out(stringBody)

val routesWithContext: ContextRoutes[String, IO] =
Http4sServerInterpreter[IO]()
// server logic is to return the context as is
.toContextRoutes(e.serverLogicSuccess(ctx => IO.pure(ctx)))

// middleware to add the context to each request (so here string constant)
val middleware: ContextMiddleware[IO, String] =
ContextMiddleware.const(expectedContext)

BlazeServerBuilder[IO]
.withExecutionContext(ExecutionContext.global)
.bindHttp(0, "localhost")
.withHttpApp(Router("/api" -> middleware(routesWithContext)).orNotFound)
.resource
.use { server =>
val port = server.address.getPort
basicRequest.get(uri"http://localhost:$port/api/test/router").send(backend).map(_.body shouldBe Right(expectedContext))
}
.unsafeRunSync()
},
createServerTest.testServer(
endpoint.out(
webSocketBody[String, CodecFormat.TextPlain, String, CodecFormat.TextPlain]
Expand Down

0 comments on commit f149c99

Please sign in to comment.