Skip to content

Commit

Permalink
Merge pull request #3147 from yabosedira/zio-http-websocket
Browse files Browse the repository at this point in the history
Initial implementation of websocket support for ZIO http.
  • Loading branch information
adamw authored Sep 19, 2023
2 parents 373815c + c5933c9 commit ec456ed
Show file tree
Hide file tree
Showing 8 changed files with 292 additions and 48 deletions.
6 changes: 6 additions & 0 deletions doc/server/ziohttp.md
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,12 @@ capability. Both response bodies and request bodies can be streamed. Usage: `str
The capability can be added to the classpath independently of the interpreter through the
`"com.softwaremill.sttp.shared" %% "zio"` dependency.

## Web sockets

The interpreter supports web sockets, with pipes of type `zio.stream.Stream[Throwable, REQ] => zio.stream.Stream[Throwable, RESP]`.
See [web sockets](../endpoint/websockets.md) for more details. It also supports auto-ping, auto-pong-on-ping, ignoring-pongs and handling
of fragmented frames.

## Error handling

By default, any endpoints interpreted with the `ZioHttpInterpreter` will use tapir's built-in failed effect handling,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@ import zio.stream.ZStream

import scala.util.{Failure, Success, Try}

private[ziohttp] class ZioHttpBodyListener[R] extends BodyListener[RIO[R, *], ZioHttpResponseBody] {
override def onComplete(body: ZioHttpResponseBody)(cb: Try[Unit] => RIO[R, Unit]): RIO[R, ZioHttpResponseBody] =
private[ziohttp] class ZioHttpBodyListener[R] extends BodyListener[RIO[R, *], ZioResponseBody] {
override def onComplete(body: ZioResponseBody)(cb: Try[Unit] => RIO[R, Unit]): RIO[R, ZioResponseBody] =
ZIO
.environmentWithZIO[R]
.apply { r =>
body match {
case ZioStreamHttpResponseBody(stream, contentLength) =>
ZIO.succeed(
case Right(ZioStreamHttpResponseBody(stream, contentLength)) =>
ZIO.right(
ZioStreamHttpResponseBody(
stream.onError(cause => cb(Failure(cause.squash)).orDie.provideEnvironment(r)) ++ ZStream
.fromZIO(cb(Success(())))
Expand All @@ -22,7 +22,8 @@ private[ziohttp] class ZioHttpBodyListener[R] extends BodyListener[RIO[R, *], Zi
contentLength
)
)
case raw: ZioRawHttpResponseBody => cb(Success(())).provideEnvironment(r).map(_ => raw)
case raw @ Right(_: ZioRawHttpResponseBody) => cb(Success(())).provideEnvironment(r).map(_ => raw)
case ws @ Left(_) => cb(Success(())).provideEnvironment(r).map(_ => ws)
}
}
}
Original file line number Diff line number Diff line change
@@ -1,22 +1,28 @@
package sttp.tapir.server.ziohttp

import sttp.capabilities.WebSockets
import sttp.capabilities.zio.ZioStreams
import sttp.model.{Method, Header => SttpHeader}
import sttp.model.Method
import sttp.model.{Header => SttpHeader}
import sttp.monad.MonadError
import sttp.tapir.server.interceptor.RequestResult
import sttp.tapir.server.interceptor.reject.RejectInterceptor
import sttp.tapir.server.interpreter.{FilterServerEndpoints, ServerInterpreter}
import sttp.tapir.server.interpreter.FilterServerEndpoints
import sttp.tapir.server.interpreter.ServerInterpreter
import sttp.tapir.server.model.ServerResponse
import sttp.tapir.ztapir._
import zio._
import zio.http.{Header => ZioHttpHeader, Headers => ZioHttpHeaders, _}
import zio.http.{Header => ZioHttpHeader}
import zio.http.{Headers => ZioHttpHeaders}
import zio.http._

trait ZioHttpInterpreter[R] {
def zioHttpServerOptions: ZioHttpServerOptions[R] = ZioHttpServerOptions.default

def toHttp[R2](se: ZServerEndpoint[R2, ZioStreams]): HttpApp[R & R2, Throwable] =
def toHttp[R2](se: ZServerEndpoint[R2, ZioStreams with WebSockets]): HttpApp[R & R2, Throwable] =
toHttp(List(se))

def toHttp[R2](ses: List[ZServerEndpoint[R2, ZioStreams]]): HttpApp[R & R2, Throwable] = {
def toHttp[R2](ses: List[ZServerEndpoint[R2, ZioStreams with WebSockets]]): HttpApp[R & R2, Throwable] = {
implicit val bodyListener: ZioHttpBodyListener[R & R2] = new ZioHttpBodyListener[R & R2]
implicit val monadError: MonadError[RIO[R & R2, *]] = new RIOMonadError[R & R2]
val widenedSes = ses.map(_.widen[R & R2])
Expand All @@ -25,9 +31,9 @@ trait ZioHttpInterpreter[R] {
val zioHttpResponseBody = new ZioHttpToResponseBody
val interceptors = RejectInterceptor.disableWhenSingleEndpoint(widenedServerOptions.interceptors, widenedSes)

def handleRequest(req: Request, filteredEndpoints: List[ZServerEndpoint[R & R2, ZioStreams]]) = {
def handleRequest(req: Request, filteredEndpoints: List[ZServerEndpoint[R & R2, ZioStreams with WebSockets]]) =
Handler.fromZIO {
val interpreter = new ServerInterpreter[ZioStreams, RIO[R & R2, *], ZioHttpResponseBody, ZioStreams](
val interpreter = new ServerInterpreter[ZioStreams with WebSockets, RIO[R & R2, *], ZioResponseBody, ZioStreams](
_ => filteredEndpoints,
zioHttpRequestBody,
zioHttpResponseBody,
Expand All @@ -41,26 +47,12 @@ trait ZioHttpInterpreter[R] {
error => ZIO.fail(error),
{
case RequestResult.Response(resp) =>
val baseHeaders = resp.headers.groupBy(_.name).flatMap(sttpToZioHttpHeader).toList
val allHeaders = resp.body.flatMap(_.contentLength) match {
case Some(contentLength) if resp.contentLength.isEmpty =>
ZioHttpHeader.ContentLength(contentLength) :: baseHeaders
case _ => baseHeaders
resp.body match {
case None => handleHttpResponse(resp, None)
case Some(Right(body)) => handleHttpResponse(resp, Some(body))
case Some(Left(body)) => handleWebSocketResponse(body)
}
val statusCode = resp.code.code

ZIO.succeed(
Response(
status = Status.fromInt(statusCode).getOrElse(Status.Custom(statusCode)),
headers = ZioHttpHeaders(allHeaders),
body = resp.body
.map {
case ZioStreamHttpResponseBody(stream, _) => Body.fromStream(stream)
case ZioRawHttpResponseBody(chunk, _) => Body.fromChunk(chunk)
}
.getOrElse(Body.empty)
)
)

case RequestResult.Failure(_) =>
ZIO.fail(
new RuntimeException(
Expand All @@ -73,9 +65,8 @@ trait ZioHttpInterpreter[R] {
}
)
}
}

val serverEndpointsFilter = FilterServerEndpoints[ZioStreams, RIO[R & R2, *]](widenedSes)
val serverEndpointsFilter = FilterServerEndpoints[ZioStreams with WebSockets, RIO[R & R2, *]](widenedSes)
val singleEndpoint = widenedSes.size == 1

Http.fromOptionalHandlerZIO { request =>
Expand All @@ -98,19 +89,54 @@ trait ZioHttpInterpreter[R] {
}
}

private def handleWebSocketResponse(webSocketHandler: WebSocketHandler): ZIO[Any, Nothing, Response] = {
Handler.webSocket { channel =>
for {
channelEventsQueue <- zio.Queue.unbounded[WebSocketChannelEvent]
messageReceptionFiber <- channel.receiveAll { message => channelEventsQueue.offer(message) }.fork
webSocketStream <- webSocketHandler(stream.ZStream.fromQueue(channelEventsQueue))
_ <- webSocketStream.mapZIO(channel.send).runDrain
} yield messageReceptionFiber.join
}.toResponse
}

private def handleHttpResponse(
resp: ServerResponse[ZioResponseBody],
body: Option[ZioHttpResponseBody]
) = {
val baseHeaders = resp.headers.groupBy(_.name).flatMap(sttpToZioHttpHeader).toList
val allHeaders = body.flatMap(_.contentLength) match {
case Some(contentLength) if resp.contentLength.isEmpty => ZioHttpHeader.ContentLength(contentLength) :: baseHeaders
case _ => baseHeaders
}
val statusCode = resp.code.code

ZIO.succeed(
Response(
status = Status.fromInt(statusCode).getOrElse(Status.Custom(statusCode)),
headers = ZioHttpHeaders(allHeaders),
body = body
.map {
case ZioStreamHttpResponseBody(stream, _) => Body.fromStream(stream)
case ZioRawHttpResponseBody(chunk, _) => Body.fromChunk(chunk)
}
.getOrElse(Body.empty)
)
)
}

private def sttpToZioHttpHeader(hl: (String, Seq[SttpHeader])): List[ZioHttpHeader] =
List(ZioHttpHeader.Custom(hl._1, hl._2.map(_.value).mkString(", ")))
}

object ZioHttpInterpreter {
def apply[R](serverOptions: ZioHttpServerOptions[R]): ZioHttpInterpreter[R] = {

def apply[R](serverOptions: ZioHttpServerOptions[R]): ZioHttpInterpreter[R] =
new ZioHttpInterpreter[R] {
override def zioHttpServerOptions: ZioHttpServerOptions[R] = serverOptions
}
}
def apply(): ZioHttpInterpreter[Any] = {
def apply(): ZioHttpInterpreter[Any] =
new ZioHttpInterpreter[Any] {
override def zioHttpServerOptions: ZioHttpServerOptions[Any] = ZioHttpServerOptions.default[Any]
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,26 +10,31 @@ import zio.stream.ZStream
import java.nio.ByteBuffer
import java.nio.charset.Charset

class ZioHttpToResponseBody extends ToResponseBody[ZioHttpResponseBody, ZioStreams] {
class ZioHttpToResponseBody extends ToResponseBody[ZioResponseBody, ZioStreams] {
override val streams: ZioStreams = ZioStreams

override def fromRawValue[R](v: R, headers: HasHeaders, format: CodecFormat, bodyType: RawBodyType[R]): ZioHttpResponseBody =
rawValueToEntity(bodyType, v)
override def fromRawValue[R](
v: R,
headers: HasHeaders,
format: CodecFormat,
bodyType: RawBodyType[R]
): ZioResponseBody =
Right(rawValueToEntity(bodyType, v))

override def fromStreamValue(
v: streams.BinaryStream,
headers: HasHeaders,
format: CodecFormat,
charset: Option[Charset]
): ZioHttpResponseBody = ZioStreamHttpResponseBody(v, None)
): ZioResponseBody = Right(ZioStreamHttpResponseBody(v, None))

override def fromWebSocketPipe[REQ, RESP](
pipe: streams.Pipe[REQ, RESP],
o: WebSocketBodyOutput[streams.Pipe[REQ, RESP], REQ, RESP, _, ZioStreams]
): ZioHttpResponseBody =
ZioStreamHttpResponseBody(ZStream.empty, None) // TODO
): ZioResponseBody =
Left(ZioWebSockets.pipeToBody(pipe, o))

private def rawValueToEntity[R](bodyType: RawBodyType[R], r: R): ZioHttpResponseBody = {
private def rawValueToEntity[R](bodyType: RawBodyType[R], r: R): ZioHttpResponseBody =
bodyType match {
case RawBodyType.StringBody(charset) =>
val bytes = r.toString.getBytes(charset)
Expand Down Expand Up @@ -71,5 +76,4 @@ class ZioHttpToResponseBody extends ToResponseBody[ZioHttpResponseBody, ZioStrea
.getOrElse(ZioStreamHttpResponseBody(ZStream.fromPath(tapirFile.file.toPath), Some(tapirFile.file.length)))
case RawBodyType.MultipartBody(_, _) => throw new UnsupportedOperationException("Multipart is not supported")
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
package sttp.tapir.server.ziohttp
import sttp.capabilities.zio.ZioStreams
import sttp.capabilities.zio.ZioStreams.Pipe
import sttp.tapir.DecodeResult
import sttp.tapir.WebSocketBodyOutput
import sttp.tapir.model.WebSocketFrameDecodeFailure
import sttp.ws.{WebSocketFrame => SttpWebSocketFrame}
import zio.Chunk
import zio.Duration.fromScala
import zio.Schedule
import zio.ZIO
import zio.http.ChannelEvent.Read
import zio.http.WebSocketChannelEvent
import zio.http.{WebSocketFrame => ZioWebSocketFrame}
import zio.stream
import zio.stream.ZStream

import scala.concurrent.duration.FiniteDuration

object ZioWebSockets {

def pipeToBody[REQ, RESP](
pipe: Pipe[REQ, RESP],
o: WebSocketBodyOutput[Pipe[REQ, RESP], REQ, RESP, _, ZioStreams]
): WebSocketHandler = {
{ (in: stream.Stream[Throwable, WebSocketChannelEvent]) =>
{
for {
pongs <- zio.Queue.bounded[SttpWebSocketFrame](1)
sttpFrames = in.map(zWebSocketChannelEventToFrame).collectSome
concatenated = optionallyConcatenate(sttpFrames, o.concatenateFragmentedFrames)
ignoredPongs = optionallyIgnorePongs(concatenated, o.ignorePong)
autoPongs = optionallyAutoPongOnPing(ignoredPongs, pongs, o.autoPongOnPing)
autoPing = optionallyAutoPing(o.autoPing)
closeStream = stream.ZStream.from[SttpWebSocketFrame](SttpWebSocketFrame.close)
intermediateStream = autoPongs
.map {
case _: SttpWebSocketFrame.Close if !o.decodeCloseRequests => None
case f =>
o.requests.decode(f) match {
case failure: DecodeResult.Failure => throw new WebSocketFrameDecodeFailure(f, failure)
case DecodeResult.Value(v) => Some(v)
}
}
.collectWhileSome
.viaFunction(pipe)
.map(o.responses.encode)
.mergeHaltLeft(stream.ZStream.fromQueue[SttpWebSocketFrame](pongs, 1))
.mergeHaltLeft(autoPing) ++ closeStream
sendReceiveStream = intermediateStream.map(frameToZWebSocketChannelEvent)
} yield sendReceiveStream
}
}
}

private def zWebSocketChannelEventToFrame(channelEvent: WebSocketChannelEvent): Option[SttpWebSocketFrame] =
channelEvent match {
case Read(f @ ZioWebSocketFrame.Text(text)) => Some(SttpWebSocketFrame.Text(text, f.isFinal, rsv = None))
case Read(f @ ZioWebSocketFrame.Binary(buffer)) => Some(SttpWebSocketFrame.Binary(buffer.toArray, f.isFinal, rsv = None))
case Read(f @ ZioWebSocketFrame.Continuation(buffer)) => Some(SttpWebSocketFrame.Binary(buffer.toArray, f.isFinal, rsv = None))
case Read(ZioWebSocketFrame.Ping) => Some(SttpWebSocketFrame.ping)
case Read(ZioWebSocketFrame.Pong) => Some(SttpWebSocketFrame.pong)
case Read(ZioWebSocketFrame.Close(status, reason)) => Some(SttpWebSocketFrame.Close(status, reason.getOrElse("")))
case Read(f) => Some(SttpWebSocketFrame.Binary(Array.empty[Byte], f.isFinal, rsv = None))
case _ => None
}

private def frameToZWebSocketChannelEvent(f: SttpWebSocketFrame): WebSocketChannelEvent =
f match {
case SttpWebSocketFrame.Text(p, finalFragment, _) => Read(ZioWebSocketFrame.Text(p, finalFragment))
case SttpWebSocketFrame.Binary(p, finalFragment, _) => Read(ZioWebSocketFrame.Binary(Chunk.fromArray(p), finalFragment))
case SttpWebSocketFrame.Ping(_) => Read(ZioWebSocketFrame.Ping)
case SttpWebSocketFrame.Pong(_) => Read(ZioWebSocketFrame.Pong)
case SttpWebSocketFrame.Close(code, reason) => Read(ZioWebSocketFrame.Close(code, Some(reason)))
}

private def optionallyIgnorePongs(
sttpFrames: ZStream[Any, Throwable, SttpWebSocketFrame],
ignorePong: Boolean
): ZStream[Any, Throwable, SttpWebSocketFrame] = {
sttpFrames
.filter {
case _: SttpWebSocketFrame.Pong if ignorePong => false
case _ => true
}
}

private def optionallyAutoPing(
autoPing: Option[(FiniteDuration, SttpWebSocketFrame.Ping)]
): ZStream[Any, Nothing, SttpWebSocketFrame] = {
autoPing match {
case Some((duration, ping)) =>
stream.ZStream
.from(ping)
.repeat(Schedule.fixed(fromScala(duration)))
case None => stream.ZStream.empty
}
}

private def optionallyAutoPongOnPing(
sttpFrames: ZStream[Any, Throwable, SttpWebSocketFrame],
pongs: zio.Queue[SttpWebSocketFrame],
autoPongOnPing: Boolean
): ZStream[Any, Throwable, SttpWebSocketFrame] = {
if (autoPongOnPing) {
sttpFrames.mapZIO {
case SttpWebSocketFrame.Ping(payload) if autoPongOnPing =>
pongs.offer(SttpWebSocketFrame.Pong(payload)).as(Option.empty[SttpWebSocketFrame])
case f => ZIO.succeed(Some(f))
}.collectSome
} else sttpFrames
}

private def optionallyConcatenate(
sttpFrames: ZStream[Any, Throwable, SttpWebSocketFrame],
concatenate: Boolean
): ZStream[Any, Throwable, SttpWebSocketFrame] = {
if (concatenate) {
type Accumulator = Option[Either[Array[Byte], String]]

sttpFrames
.mapAccum(None: Accumulator) {
case (None, f: SttpWebSocketFrame.Ping) => (None, Some(f))
case (None, f: SttpWebSocketFrame.Pong) => (None, Some(f))
case (None, f: SttpWebSocketFrame.Close) => (None, Some(f))
case (None, f: SttpWebSocketFrame.Data[_]) if f.finalFragment => (None, Some(f))
case (Some(Left(acc)), f: SttpWebSocketFrame.Binary) if f.finalFragment => (None, Some(f.copy(payload = acc ++ f.payload)))
case (Some(Left(acc)), f: SttpWebSocketFrame.Binary) if !f.finalFragment => (Some(Left(acc ++ f.payload)), None)
case (Some(Right(acc)), f: SttpWebSocketFrame.Text) if f.finalFragment =>
println(s"final fragment: $f")
println(s"acc: $acc")
(None, Some(f.copy(payload = acc + f.payload)))
case (Some(Right(acc)), f: SttpWebSocketFrame.Text) if !f.finalFragment =>
println(s"final fragment: $f")
println(s"acc: $acc")
(Some(Right(acc + f.payload)), None)

case (acc, f) => throw new IllegalStateException(s"Cannot accumulate web socket frames. Accumulator: $acc, frame: $f.")
}
.collectSome
} else sttpFrames
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package sttp.tapir.server
import zio.http.WebSocketChannelEvent
import zio.{ZIO, stream}

package object ziohttp {
type WebSocketHandler =
stream.Stream[Throwable, WebSocketChannelEvent] => ZIO[Any, Throwable, stream.Stream[Throwable, WebSocketChannelEvent]]

type ZioResponseBody =
Either[WebSocketHandler, ZioHttpResponseBody]

}
Loading

0 comments on commit ec456ed

Please sign in to comment.