Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Http4s context routes support in server interpreter #3101

Merged
merged 2 commits into from
Aug 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions doc/server/http4s.md
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,37 @@ val routes = Http4sServerInterpreter[IO]().toRoutes(sseEndpoint.serverLogicSucce
))
```

## Accessing http4s context

If you'd like to access context provided by an http4s middleware, e.g. with authentication data, this can be done
with a dedicated context-extracting input, `.contextIn`. Endpoints using such input need then to be interpreted to
`org.http4s.ContextRoutes` (also known by its type alias `AuthedRoutes`) using the `.toContextRoutes` method.

For example:

```scala mdoc:compile-only
import sttp.tapir._
import sttp.tapir.server.http4s._
import cats.effect.IO
import org.http4s.ContextRoutes

case class SomeCtx(actionAllowed: Boolean) // the context expected from http4s middleware

def countCharacters(in: (String, SomeCtx)): IO[Either[Unit, Int]] =
IO.pure(
if(in._2.actionAllowed) Right[Unit, Int](in._1.length) else Left[Unit, Int](())
)

// the .contextIn extension method is imported from the sttp.tapir.server.http4s package
// the Context[SomeCtx] capability requirement requires interpretation to be done using .toContextRoutes
val countCharactersEndpoint: PublicEndpoint[(String, SomeCtx), Unit, Int, Context[SomeCtx]] =
endpoint.in(stringBody).contextIn[SomeCtx]().out(plainBody[Int])

val countCharactersRoutes: ContextRoutes[SomeCtx, IO] =
Http4sServerInterpreter[IO]()
.toContextRoutes(countCharactersEndpoint.serverLogic(countCharacters _))
```

## Configuration

The interpreter can be configured by providing an `Http4sServerOptions` value, see
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,27 +2,31 @@ package sttp.tapir.server.http4s

import cats.data.{Kleisli, OptionT}
import cats.effect.Async
import cats.effect.std.Queue
import cats.implicits._
import fs2.{Pipe, Stream}
import org.http4s._
import org.http4s.headers.`Content-Length`
import org.http4s.server.websocket.WebSocketBuilder2
import org.http4s.websocket.WebSocketFrame
import org.typelevel.ci.CIString
import sttp.capabilities.WebSockets
import sttp.capabilities.fs2.Fs2Streams
import sttp.tapir._
import sttp.tapir.integ.cats.effect.CatsMonadError
import sttp.tapir.server.ServerEndpoint
import sttp.tapir.server.interceptor.RequestResult
import sttp.tapir.server.interceptor.reject.RejectInterceptor
import sttp.tapir.server.interpreter.{BodyListener, FilterServerEndpoints, ServerInterpreter}
import sttp.tapir.server.model.ServerResponse

import scala.reflect.ClassTag

class Http4sInvalidWebSocketUse(val message: String) extends Exception

trait Http4sServerInterpreter[F[_]] {
/** A capability that is used by endpoints, when they need to access the http4s-provided context. Such a requirement can be added using the
* [[RichHttp4sEndpoint.contextIn]] method.
*/
trait Context[T]

trait Http4sServerInterpreter[F[_]] {
implicit def fa: Async[F]

def http4sServerOptions: Http4sServerOptions[F] = Http4sServerOptions.default[F]
Expand All @@ -42,28 +46,62 @@ trait Http4sServerInterpreter[F[_]] {
serverEndpoints: List[ServerEndpoint[Fs2Streams[F] with WebSockets, F]]
): WebSocketBuilder2[F] => HttpRoutes[F] = wsb => toRoutes(serverEndpoints, Some(wsb))

private def toRoutes(
serverEndpoints: List[ServerEndpoint[Fs2Streams[F] with WebSockets, F]],
webSocketBuilder: Option[WebSocketBuilder2[F]]
): HttpRoutes[F] = {
def toContextRoutes[T: ClassTag](se: ServerEndpoint[Fs2Streams[F] with Context[T], F]): ContextRoutes[T, F] =
toContextRoutes(contextAttributeKey[T], List(se), None)

def toContextRoutes[T: ClassTag](ses: List[ServerEndpoint[Fs2Streams[F] with Context[T], F]]): ContextRoutes[T, F] =
toContextRoutes(contextAttributeKey[T], ses, None)

private def createInterpreter[T](
serverEndpoints: List[ServerEndpoint[Fs2Streams[F] with WebSockets with Context[T], F]]
): ServerInterpreter[Fs2Streams[F] with WebSockets with Context[T], F, Http4sResponseBody[F], Fs2Streams[F]] = {
implicit val monad: CatsMonadError[F] = new CatsMonadError[F]
implicit val bodyListener: BodyListener[F, Http4sResponseBody[F]] = new Http4sBodyListener[F]

val interpreter = new ServerInterpreter[Fs2Streams[F] with WebSockets, F, Http4sResponseBody[F], Fs2Streams[F]](
new ServerInterpreter(
FilterServerEndpoints(serverEndpoints),
new Http4sRequestBody[F](http4sServerOptions),
new Http4sToResponseBody[F](http4sServerOptions),
RejectInterceptor.disableWhenSingleEndpoint(http4sServerOptions.interceptors, serverEndpoints),
http4sServerOptions.deleteFile
)
}

private def toResponse[T](
interpreter: ServerInterpreter[Fs2Streams[F] with WebSockets with Context[T], F, Http4sResponseBody[F], Fs2Streams[F]],
serverRequest: Http4sServerRequest[F],
webSocketBuilder: Option[WebSocketBuilder2[F]]
): OptionT[F, Response[F]] =
OptionT(interpreter(serverRequest).flatMap {
case _: RequestResult.Failure => none.pure[F]
case RequestResult.Response(response) => serverResponseToHttp4s(response, webSocketBuilder).map(_.some)
})

private def toRoutes(
serverEndpoints: List[ServerEndpoint[Fs2Streams[F] with WebSockets, F]],
webSocketBuilder: Option[WebSocketBuilder2[F]]
): HttpRoutes[F] = {
val interpreter = createInterpreter(serverEndpoints)

Kleisli { (req: Request[F]) =>
val serverRequest = Http4sServerRequest(req)
toResponse(interpreter, serverRequest, webSocketBuilder)
}
}

OptionT(interpreter(serverRequest).flatMap {
case _: RequestResult.Failure => none.pure[F]
case RequestResult.Response(response) => serverResponseToHttp4s(response, webSocketBuilder).map(_.some)
})
private def toContextRoutes[T](
contextAttributeKey: AttributeKey[T],
serverEndpoints: List[ServerEndpoint[Fs2Streams[F] with WebSockets with Context[T], F]],
webSocketBuilder: Option[WebSocketBuilder2[F]]
): ContextRoutes[T, F] = {
val interpreter = createInterpreter(serverEndpoints)

Kleisli { (contextRequest: ContextRequest[F, T]) =>
val serverRequest = Http4sServerRequest(
contextRequest.req,
AttributeMap.Empty.put(contextAttributeKey, contextRequest.context)
)
toResponse(interpreter, serverRequest, webSocketBuilder)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,12 @@ import org.http4s.EntityBody
import org.http4s.websocket.WebSocketFrame
import sttp.capabilities.fs2.Fs2Streams
import sttp.model.sse.ServerSentEvent
import sttp.tapir.{CodecFormat, StreamBodyIO, streamTextBody}
import sttp.tapir.model.ServerRequest
import sttp.tapir.typelevel.ParamConcat
import sttp.tapir.{AttributeKey, CodecFormat, Endpoint, StreamBodyIO, extractFromRequest, streamTextBody}

import java.nio.charset.Charset
import scala.reflect.ClassTag

package object http4s {
// either a web socket, or a stream with optional length (if known)
Expand All @@ -18,4 +21,31 @@ package object http4s {
streamTextBody(fs2Streams)(CodecFormat.TextEventStream(), Some(Charset.forName("UTF-8")))
.map(Http4sServerSentEvents.parseBytesToSSE[F])(Http4sServerSentEvents.serialiseSSEToBytes[F])
}

private[http4s] def contextAttributeKey[T: ClassTag]: AttributeKey[T] = new AttributeKey(implicitly[ClassTag[T]].runtimeClass.getName)

implicit class RichHttp4sEndpoint[A, I, E, O, R](e: Endpoint[A, I, E, O, R]) {

/** Access the context provided by an http4s middleware, such as authentication data.
*
* Interpreting endpoints which access the http4s context requires the usage of the [[Http4sServerInterpreter.toContextRoutes]]
* method. This then yields a [[org.http4s.ContextRoutes]] instance, which needs to be correctly mounted in the http4s router.
*
* Note that the correct syntax for adding the context input includes `()` after the method invocation, to properly infer types and
* capture implicit parameters, e.g. `myEndpoint.contextIn[Auth]()`.
*/
def contextIn[T]: AddContextInput[T] = new AddContextInput[T]

class AddContextInput[T] {
def apply[IT]()(implicit concat: ParamConcat.Aux[I, T, IT], ct: ClassTag[T]): Endpoint[A, IT, E, O, R with Context[T]] = {
val attribute = contextAttributeKey[T]
e.in(extractFromRequest[T] { (req: ServerRequest) =>
req
.attribute(attribute)
// should never happen since http4s had to build a ContextRequest with Ctx for ContextRoutes
.getOrElse(throw new RuntimeException(s"context ${attribute.typeName} not found in the request"))
})
}
}
}
}
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
package sttp.tapir.server.http4s

import cats.data._
import cats.effect._
import cats.effect.unsafe.implicits.global
import cats.syntax.all._
import fs2.Pipe
import org.http4s.blaze.server.BlazeServerBuilder
import org.http4s.server.Router
import org.http4s.server.ContextMiddleware
import org.http4s.ContextRoutes
import org.http4s.HttpRoutes
import org.scalatest.OptionValues
import org.scalatest.matchers.should.Matchers._
import sttp.capabilities.WebSockets
Expand Down Expand Up @@ -50,6 +54,32 @@ class Http4sServerTest[R >: Fs2Streams[IO] with WebSockets] extends TestSuite wi
}
.unsafeRunSync()
},
Test("should work with a router and context routes in a context") {
val expectedContext: String = "Hello World!" // the context we expect http4s to provide to the endpoint

val e: Endpoint[Unit, String, Unit, String, Context[String]] =
endpoint.get.in("test" / "router").contextIn[String]().out(stringBody)

val routesWithContext: ContextRoutes[String, IO] =
Http4sServerInterpreter[IO]()
// server logic is to return the context as is
.toContextRoutes(e.serverLogicSuccess(ctx => IO.pure(ctx)))

// middleware to add the context to each request (so here string constant)
val middleware: ContextMiddleware[IO, String] =
ContextMiddleware.const(expectedContext)

BlazeServerBuilder[IO]
.withExecutionContext(ExecutionContext.global)
.bindHttp(0, "localhost")
.withHttpApp(Router("/api" -> middleware(routesWithContext)).orNotFound)
.resource
.use { server =>
val port = server.address.getPort
basicRequest.get(uri"http://localhost:$port/api/test/router").send(backend).map(_.body shouldBe Right(expectedContext))
}
.unsafeRunSync()
},
createServerTest.testServer(
endpoint.out(
webSocketBody[String, CodecFormat.TextPlain, String, CodecFormat.TextPlain]
Expand Down