Skip to content

Commit

Permalink
Initial implementation of zio-http web-sockets support
Browse files Browse the repository at this point in the history
  • Loading branch information
yabosedira committed Aug 31, 2023
1 parent 40614d7 commit 97444d8
Show file tree
Hide file tree
Showing 7 changed files with 176 additions and 77 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@ import zio.stream.ZStream

import scala.util.{Failure, Success, Try}

private[ziohttp] class ZioHttpBodyListener[R] extends BodyListener[RIO[R, *], ZioHttpResponseBody] {
override def onComplete(body: ZioHttpResponseBody)(cb: Try[Unit] => RIO[R, Unit]): RIO[R, ZioHttpResponseBody] =
private[ziohttp] class ZioHttpBodyListener[R] extends BodyListener[RIO[R, *], ZioResponseBody] {
override def onComplete(body: ZioResponseBody)(cb: Try[Unit] => RIO[R, Unit]): RIO[R, ZioResponseBody] =
ZIO
.environmentWithZIO[R]
.apply { r =>
body match {
case ZioStreamHttpResponseBody(stream, contentLength) =>
ZIO.succeed(
case Right(ZioStreamHttpResponseBody(stream, contentLength)) =>
ZIO.right(
ZioStreamHttpResponseBody(
stream.onError(cause => cb(Failure(cause.squash)).orDie.provideEnvironment(r)) ++ ZStream
.fromZIO(cb(Success(())))
Expand All @@ -22,7 +22,8 @@ private[ziohttp] class ZioHttpBodyListener[R] extends BodyListener[RIO[R, *], Zi
contentLength
)
)
case raw: ZioRawHttpResponseBody => cb(Success(())).provideEnvironment(r).map(_ => raw)
case raw @ Right(_: ZioRawHttpResponseBody) => cb(Success(())).provideEnvironment(r).map(_ => raw)
case ws @ Left(_) => cb(Success(())).provideEnvironment(r).map(_ => ws)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@ import sttp.monad.MonadError
import sttp.tapir.server.interceptor.RequestResult
import sttp.tapir.server.interceptor.reject.RejectInterceptor
import sttp.tapir.server.interpreter.{FilterServerEndpoints, ServerInterpreter}
import sttp.tapir.server.model.ServerResponse
import sttp.tapir.ztapir._
import zio._
import zio.http.ChannelEvent.Read
import zio.http.{Header => ZioHttpHeader, Headers => ZioHttpHeaders, _}

trait ZioHttpInterpreter[R] {
Expand All @@ -19,16 +21,16 @@ trait ZioHttpInterpreter[R] {

def toHttp[R2](ses: List[ZServerEndpoint[R2, ZioStreams with WebSockets]]): HttpApp[R & R2, Throwable] = {
implicit val bodyListener: ZioHttpBodyListener[R & R2] = new ZioHttpBodyListener[R & R2]
implicit val monadError: MonadError[RIO[R & R2, *]] = new RIOMonadError[R & R2]
val widenedSes = ses.map(_.widen[R & R2])
val widenedServerOptions = zioHttpServerOptions.widen[R & R2]
val zioHttpRequestBody = new ZioHttpRequestBody(widenedServerOptions)
val zioHttpResponseBody = new ZioHttpToResponseBody
val interceptors = RejectInterceptor.disableWhenSingleEndpoint(widenedServerOptions.interceptors, widenedSes)
implicit val monadError: MonadError[RIO[R & R2, *]] = new RIOMonadError[R & R2]
val widenedSes = ses.map(_.widen[R & R2])
val widenedServerOptions = zioHttpServerOptions.widen[R & R2]
val zioHttpRequestBody = new ZioHttpRequestBody(widenedServerOptions)
val zioHttpResponseBody = new ZioHttpToResponseBody
val interceptors = RejectInterceptor.disableWhenSingleEndpoint(widenedServerOptions.interceptors, widenedSes)

def handleRequest(req: Request, filteredEndpoints: List[ZServerEndpoint[R & R2, ZioStreams with WebSockets]]) =
Handler.fromZIO {
val interpreter = new ServerInterpreter[ZioStreams with WebSockets, RIO[R & R2, *], ZioHttpResponseBody, ZioStreams](
val interpreter = new ServerInterpreter[ZioStreams with WebSockets, RIO[R & R2, *], ZioResponseBody, ZioStreams](
_ => filteredEndpoints,
zioHttpRequestBody,
zioHttpResponseBody,
Expand All @@ -42,27 +44,27 @@ trait ZioHttpInterpreter[R] {
error => ZIO.fail(error),
{
case RequestResult.Response(resp) =>
val baseHeaders = resp.headers.groupBy(_.name).flatMap(sttpToZioHttpHeader).toList
val allHeaders = resp.body.flatMap(_.contentLength) match {
case Some(contentLength) if resp.contentLength.isEmpty =>
ZioHttpHeader.ContentLength(contentLength) :: baseHeaders
case _ => baseHeaders
}
val statusCode = resp.code.code
resp.body match {
case None => handleHttpResponse(resp, None)
case Some(Right(body)) => handleHttpResponse(resp, Some(body))
case Some(Left(body)) =>
Handler.webSocket { channel =>
{
channel.receiveAll {
case ChannelEvent.Read(message) =>
for {
m <- body(message)
_ <- ZIO.foldLeft(m)(())((_, z) => channel.send(Read(z)))
} yield ()

ZIO.succeed(
Response(
status = Status.fromInt(statusCode).getOrElse(Status.Custom(statusCode)),
headers = ZioHttpHeaders(allHeaders),
body = resp.body
.map {
case ZioStreamHttpResponseBody(stream, _) => Body.fromStream(stream)
case ZioRawHttpResponseBody(chunk, _) => Body.fromChunk(chunk)
case v =>
channel.send(v)
}
}
.getOrElse(Body.empty)
)
)
case RequestResult.Failure(_) =>
}.toResponse
}

case RequestResult.Failure(_) =>
ZIO.fail(
new RuntimeException(
s"The path: ${req.path} matches the shape of some endpoint, but none of the " +
Expand All @@ -76,11 +78,11 @@ trait ZioHttpInterpreter[R] {
}

val serverEndpointsFilter = FilterServerEndpoints[ZioStreams with WebSockets, RIO[R & R2, *]](widenedSes)
val singleEndpoint = widenedSes.size == 1
val singleEndpoint = widenedSes.size == 1

Http.fromOptionalHandlerZIO { request =>
// pre-filtering the endpoints by shape to determine, if this request should be handled by tapir
val filteredEndpoints = serverEndpointsFilter.apply(ZioHttpServerRequest(request))
val filteredEndpoints = serverEndpointsFilter.apply(ZioHttpServerRequest(request))
val filteredEndpoints2 = if (singleEndpoint) {
// If we are interpreting a single endpoint, we verify that the method matches as well; in case it doesn't,
// we refuse to handle the request, allowing other ZIO Http routes to handle it. Otherwise even if the method
Expand All @@ -98,6 +100,31 @@ trait ZioHttpInterpreter[R] {
}
}

private def handleHttpResponse(
resp: ServerResponse[ZioResponseBody],
body: Option[ZioHttpResponseBody]
) = {
val baseHeaders = resp.headers.groupBy(_.name).flatMap(sttpToZioHttpHeader).toList
val allHeaders = body.flatMap(_.contentLength) match {
case Some(contentLength) if resp.contentLength.isEmpty => ZioHttpHeader.ContentLength(contentLength) :: baseHeaders
case _ => baseHeaders
}
val statusCode = resp.code.code

ZIO.succeed(
Response(
status = Status.fromInt(statusCode).getOrElse(Status.Custom(statusCode)),
headers = ZioHttpHeaders(allHeaders),
body = body
.map {
case ZioStreamHttpResponseBody(stream, _) => Body.fromStream(stream)
case ZioRawHttpResponseBody(chunk, _) => Body.fromChunk(chunk)
}
.getOrElse(Body.empty)
)
)
}

private def sttpToZioHttpHeader(hl: (String, Seq[SttpHeader])): List[ZioHttpHeader] =
List(ZioHttpHeader.Custom(hl._1, hl._2.map(_.value).mkString(", ")))
}
Expand All @@ -107,7 +134,7 @@ object ZioHttpInterpreter {
new ZioHttpInterpreter[R] {
override def zioHttpServerOptions: ZioHttpServerOptions[R] = serverOptions
}
def apply(): ZioHttpInterpreter[Any] =
def apply(): ZioHttpInterpreter[Any] =
new ZioHttpInterpreter[Any] {
override def zioHttpServerOptions: ZioHttpServerOptions[Any] = ZioHttpServerOptions.default[Any]
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,26 +10,31 @@ import zio.stream.ZStream
import java.nio.ByteBuffer
import java.nio.charset.Charset

class ZioHttpToResponseBody extends ToResponseBody[ZioHttpResponseBody, ZioStreams] {
class ZioHttpToResponseBody extends ToResponseBody[ZioResponseBody, ZioStreams] {
override val streams: ZioStreams = ZioStreams

override def fromRawValue[R](v: R, headers: HasHeaders, format: CodecFormat, bodyType: RawBodyType[R]): ZioHttpResponseBody =
rawValueToEntity(bodyType, v)
override def fromRawValue[R](
v: R,
headers: HasHeaders,
format: CodecFormat,
bodyType: RawBodyType[R]
): ZioResponseBody =
Right(rawValueToEntity(bodyType, v))

override def fromStreamValue(
v: streams.BinaryStream,
headers: HasHeaders,
format: CodecFormat,
charset: Option[Charset]
): ZioHttpResponseBody = ZioStreamHttpResponseBody(v, None)
): ZioResponseBody = Right(ZioStreamHttpResponseBody(v, None))

override def fromWebSocketPipe[REQ, RESP](
pipe: streams.Pipe[REQ, RESP],
o: WebSocketBodyOutput[streams.Pipe[REQ, RESP], REQ, RESP, _, ZioStreams]
): ZioHttpResponseBody =
ZioStreamHttpResponseBody(ZStream.empty, None) // TODO
): ZioResponseBody =
Left(ZioWebSockets.pipeToBody(pipe, o))

private def rawValueToEntity[R](bodyType: RawBodyType[R], r: R): ZioHttpResponseBody = {
private def rawValueToEntity[R](bodyType: RawBodyType[R], r: R): ZioHttpResponseBody =
bodyType match {
case RawBodyType.StringBody(charset) =>
val bytes = r.toString.getBytes(charset)
Expand Down Expand Up @@ -71,5 +76,4 @@ class ZioHttpToResponseBody extends ToResponseBody[ZioHttpResponseBody, ZioStrea
.getOrElse(ZioStreamHttpResponseBody(ZStream.fromPath(tapirFile.file.toPath), Some(tapirFile.file.length)))
case RawBodyType.MultipartBody(_, _) => throw new UnsupportedOperationException("Multipart is not supported")
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package sttp.tapir.server.ziohttp
import sttp.capabilities.zio.ZioStreams
import sttp.capabilities.zio.ZioStreams.Pipe
import sttp.tapir.model.WebSocketFrameDecodeFailure
import sttp.tapir.{DecodeResult, WebSocketBodyOutput}
import sttp.ws.WebSocketFrame
import zio.http.{WebSocketFrame => ZWebSocketFrame}
import zio.stream.ZStream
import zio.{Chunk, ZIO}

object ZioWebSockets {
def pipeToBody[REQ, RESP](
pipe: Pipe[REQ, RESP],
o: WebSocketBodyOutput[Pipe[REQ, RESP], REQ, RESP, _, ZioStreams]
): F2F = { in =>
ZStream
.from(in)
.map(zFrameToFrame)
.map {
case WebSocketFrame.Close(_, _) if !o.decodeCloseRequests => None
case f =>
o.requests.decode(f) match {
case failure: DecodeResult.Failure => throw new WebSocketFrameDecodeFailure(f, failure)
case DecodeResult.Value(v) => Some(v)
}
}
.collectWhileSome
.viaFunction(pipe)
.map(o.responses.encode)
.map(frameToZFrame)
.tap(v => zio.ZIO.succeed(println(v)))
.runFoldZIO(List.empty[ZWebSocketFrame])((s, ws) => ZIO.succeed(ws +: s))
}

private def zFrameToFrame(f: ZWebSocketFrame): WebSocketFrame =
f match {
case ZWebSocketFrame.Text(text) => WebSocketFrame.Text(text, f.isFinal, rsv = None)
case ZWebSocketFrame.Binary(buffer) => WebSocketFrame.Binary(buffer.toArray, f.isFinal, rsv = None)
case ZWebSocketFrame.Continuation(buffer) => WebSocketFrame.Binary(buffer.toArray, f.isFinal, rsv = None)
case ZWebSocketFrame.Ping => WebSocketFrame.ping
case ZWebSocketFrame.Pong => WebSocketFrame.pong
case ZWebSocketFrame.Close(status, reason) => WebSocketFrame.Close(status, reason.getOrElse(""))
case _ => WebSocketFrame.Binary(Array.empty[Byte], f.isFinal, rsv = None)
}

private def frameToZFrame(f: WebSocketFrame): ZWebSocketFrame =
f match {
case WebSocketFrame.Text(p, finalFragment, _) => ZWebSocketFrame.Text(p, finalFragment)
case WebSocketFrame.Binary(p, finalFragment, _) => ZWebSocketFrame.Binary(Chunk.fromArray(p), finalFragment)
case WebSocketFrame.Ping(_) => ZWebSocketFrame.Ping
case WebSocketFrame.Pong(_) => ZWebSocketFrame.Pong
case WebSocketFrame.Close(code, reason) => ZWebSocketFrame.Close(code, Some(reason))
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package sttp.tapir.server
import zio.Task
import zio.http.{WebSocketFrame => ZWebSocketFrame}

package object ziohttp {
type F2F = ZWebSocketFrame => Task[List[ZWebSocketFrame]]

type ZioResponseBody =
Either[F2F, ZioHttpResponseBody]

}
Loading

0 comments on commit 97444d8

Please sign in to comment.