Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Initial implementation of websocket support for ZIO http. #3147

Merged
merged 5 commits into from
Sep 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions doc/server/ziohttp.md
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,12 @@ capability. Both response bodies and request bodies can be streamed. Usage: `str
The capability can be added to the classpath independently of the interpreter through the
`"com.softwaremill.sttp.shared" %% "zio"` dependency.

## Web sockets

The interpreter supports web sockets, with pipes of type `zio.stream.Stream[Throwable, REQ] => zio.stream.Stream[Throwable, RESP]`.
See [web sockets](../endpoint/websockets.md) for more details. It also supports auto-ping, auto-pong-on-ping, ignoring-pongs and handling
of fragmented frames.

## Error handling

By default, any endpoints interpreted with the `ZioHttpInterpreter` will use tapir's built-in failed effect handling,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@ import zio.stream.ZStream

import scala.util.{Failure, Success, Try}

private[ziohttp] class ZioHttpBodyListener[R] extends BodyListener[RIO[R, *], ZioHttpResponseBody] {
override def onComplete(body: ZioHttpResponseBody)(cb: Try[Unit] => RIO[R, Unit]): RIO[R, ZioHttpResponseBody] =
private[ziohttp] class ZioHttpBodyListener[R] extends BodyListener[RIO[R, *], ZioResponseBody] {
override def onComplete(body: ZioResponseBody)(cb: Try[Unit] => RIO[R, Unit]): RIO[R, ZioResponseBody] =
ZIO
.environmentWithZIO[R]
.apply { r =>
body match {
case ZioStreamHttpResponseBody(stream, contentLength) =>
ZIO.succeed(
case Right(ZioStreamHttpResponseBody(stream, contentLength)) =>
ZIO.right(
ZioStreamHttpResponseBody(
stream.onError(cause => cb(Failure(cause.squash)).orDie.provideEnvironment(r)) ++ ZStream
.fromZIO(cb(Success(())))
Expand All @@ -22,7 +22,8 @@ private[ziohttp] class ZioHttpBodyListener[R] extends BodyListener[RIO[R, *], Zi
contentLength
)
)
case raw: ZioRawHttpResponseBody => cb(Success(())).provideEnvironment(r).map(_ => raw)
case raw @ Right(_: ZioRawHttpResponseBody) => cb(Success(())).provideEnvironment(r).map(_ => raw)
case ws @ Left(_) => cb(Success(())).provideEnvironment(r).map(_ => ws)
}
}
}
Original file line number Diff line number Diff line change
@@ -1,22 +1,28 @@
package sttp.tapir.server.ziohttp

import sttp.capabilities.WebSockets
import sttp.capabilities.zio.ZioStreams
import sttp.model.{Method, Header => SttpHeader}
import sttp.model.Method
import sttp.model.{Header => SttpHeader}
import sttp.monad.MonadError
import sttp.tapir.server.interceptor.RequestResult
import sttp.tapir.server.interceptor.reject.RejectInterceptor
import sttp.tapir.server.interpreter.{FilterServerEndpoints, ServerInterpreter}
import sttp.tapir.server.interpreter.FilterServerEndpoints
import sttp.tapir.server.interpreter.ServerInterpreter
import sttp.tapir.server.model.ServerResponse
import sttp.tapir.ztapir._
import zio._
import zio.http.{Header => ZioHttpHeader, Headers => ZioHttpHeaders, _}
import zio.http.{Header => ZioHttpHeader}
import zio.http.{Headers => ZioHttpHeaders}
import zio.http._

trait ZioHttpInterpreter[R] {
def zioHttpServerOptions: ZioHttpServerOptions[R] = ZioHttpServerOptions.default

def toHttp[R2](se: ZServerEndpoint[R2, ZioStreams]): HttpApp[R & R2, Throwable] =
def toHttp[R2](se: ZServerEndpoint[R2, ZioStreams with WebSockets]): HttpApp[R & R2, Throwable] =
toHttp(List(se))

def toHttp[R2](ses: List[ZServerEndpoint[R2, ZioStreams]]): HttpApp[R & R2, Throwable] = {
def toHttp[R2](ses: List[ZServerEndpoint[R2, ZioStreams with WebSockets]]): HttpApp[R & R2, Throwable] = {
implicit val bodyListener: ZioHttpBodyListener[R & R2] = new ZioHttpBodyListener[R & R2]
implicit val monadError: MonadError[RIO[R & R2, *]] = new RIOMonadError[R & R2]
val widenedSes = ses.map(_.widen[R & R2])
Expand All @@ -25,9 +31,9 @@ trait ZioHttpInterpreter[R] {
val zioHttpResponseBody = new ZioHttpToResponseBody
val interceptors = RejectInterceptor.disableWhenSingleEndpoint(widenedServerOptions.interceptors, widenedSes)

def handleRequest(req: Request, filteredEndpoints: List[ZServerEndpoint[R & R2, ZioStreams]]) = {
def handleRequest(req: Request, filteredEndpoints: List[ZServerEndpoint[R & R2, ZioStreams with WebSockets]]) =
Handler.fromZIO {
val interpreter = new ServerInterpreter[ZioStreams, RIO[R & R2, *], ZioHttpResponseBody, ZioStreams](
val interpreter = new ServerInterpreter[ZioStreams with WebSockets, RIO[R & R2, *], ZioResponseBody, ZioStreams](
_ => filteredEndpoints,
zioHttpRequestBody,
zioHttpResponseBody,
Expand All @@ -41,26 +47,12 @@ trait ZioHttpInterpreter[R] {
error => ZIO.fail(error),
{
case RequestResult.Response(resp) =>
val baseHeaders = resp.headers.groupBy(_.name).flatMap(sttpToZioHttpHeader).toList
val allHeaders = resp.body.flatMap(_.contentLength) match {
case Some(contentLength) if resp.contentLength.isEmpty =>
ZioHttpHeader.ContentLength(contentLength) :: baseHeaders
case _ => baseHeaders
resp.body match {
case None => handleHttpResponse(resp, None)
case Some(Right(body)) => handleHttpResponse(resp, Some(body))
case Some(Left(body)) => handleWebSocketResponse(body)
}
val statusCode = resp.code.code

ZIO.succeed(
Response(
status = Status.fromInt(statusCode).getOrElse(Status.Custom(statusCode)),
headers = ZioHttpHeaders(allHeaders),
body = resp.body
.map {
case ZioStreamHttpResponseBody(stream, _) => Body.fromStream(stream)
case ZioRawHttpResponseBody(chunk, _) => Body.fromChunk(chunk)
}
.getOrElse(Body.empty)
)
)

case RequestResult.Failure(_) =>
ZIO.fail(
new RuntimeException(
Expand All @@ -73,9 +65,8 @@ trait ZioHttpInterpreter[R] {
}
)
}
}

val serverEndpointsFilter = FilterServerEndpoints[ZioStreams, RIO[R & R2, *]](widenedSes)
val serverEndpointsFilter = FilterServerEndpoints[ZioStreams with WebSockets, RIO[R & R2, *]](widenedSes)
val singleEndpoint = widenedSes.size == 1

Http.fromOptionalHandlerZIO { request =>
Expand All @@ -98,19 +89,54 @@ trait ZioHttpInterpreter[R] {
}
}

private def handleWebSocketResponse(webSocketHandler: WebSocketHandler): ZIO[Any, Nothing, Response] = {
Handler.webSocket { channel =>
for {
channelEventsQueue <- zio.Queue.unbounded[WebSocketChannelEvent]
messageReceptionFiber <- channel.receiveAll { message => channelEventsQueue.offer(message) }.fork
webSocketStream <- webSocketHandler(stream.ZStream.fromQueue(channelEventsQueue))
_ <- webSocketStream.mapZIO(channel.send).runDrain
} yield messageReceptionFiber.join
}.toResponse
}

private def handleHttpResponse(
resp: ServerResponse[ZioResponseBody],
body: Option[ZioHttpResponseBody]
) = {
val baseHeaders = resp.headers.groupBy(_.name).flatMap(sttpToZioHttpHeader).toList
val allHeaders = body.flatMap(_.contentLength) match {
case Some(contentLength) if resp.contentLength.isEmpty => ZioHttpHeader.ContentLength(contentLength) :: baseHeaders
case _ => baseHeaders
}
val statusCode = resp.code.code

ZIO.succeed(
Response(
status = Status.fromInt(statusCode).getOrElse(Status.Custom(statusCode)),
headers = ZioHttpHeaders(allHeaders),
body = body
.map {
case ZioStreamHttpResponseBody(stream, _) => Body.fromStream(stream)
case ZioRawHttpResponseBody(chunk, _) => Body.fromChunk(chunk)
}
.getOrElse(Body.empty)
)
)
}

private def sttpToZioHttpHeader(hl: (String, Seq[SttpHeader])): List[ZioHttpHeader] =
List(ZioHttpHeader.Custom(hl._1, hl._2.map(_.value).mkString(", ")))
}

object ZioHttpInterpreter {
def apply[R](serverOptions: ZioHttpServerOptions[R]): ZioHttpInterpreter[R] = {

def apply[R](serverOptions: ZioHttpServerOptions[R]): ZioHttpInterpreter[R] =
new ZioHttpInterpreter[R] {
override def zioHttpServerOptions: ZioHttpServerOptions[R] = serverOptions
}
}
def apply(): ZioHttpInterpreter[Any] = {
def apply(): ZioHttpInterpreter[Any] =
new ZioHttpInterpreter[Any] {
override def zioHttpServerOptions: ZioHttpServerOptions[Any] = ZioHttpServerOptions.default[Any]
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,26 +10,31 @@ import zio.stream.ZStream
import java.nio.ByteBuffer
import java.nio.charset.Charset

class ZioHttpToResponseBody extends ToResponseBody[ZioHttpResponseBody, ZioStreams] {
class ZioHttpToResponseBody extends ToResponseBody[ZioResponseBody, ZioStreams] {
override val streams: ZioStreams = ZioStreams

override def fromRawValue[R](v: R, headers: HasHeaders, format: CodecFormat, bodyType: RawBodyType[R]): ZioHttpResponseBody =
rawValueToEntity(bodyType, v)
override def fromRawValue[R](
v: R,
headers: HasHeaders,
format: CodecFormat,
bodyType: RawBodyType[R]
): ZioResponseBody =
Right(rawValueToEntity(bodyType, v))

override def fromStreamValue(
v: streams.BinaryStream,
headers: HasHeaders,
format: CodecFormat,
charset: Option[Charset]
): ZioHttpResponseBody = ZioStreamHttpResponseBody(v, None)
): ZioResponseBody = Right(ZioStreamHttpResponseBody(v, None))

override def fromWebSocketPipe[REQ, RESP](
pipe: streams.Pipe[REQ, RESP],
o: WebSocketBodyOutput[streams.Pipe[REQ, RESP], REQ, RESP, _, ZioStreams]
): ZioHttpResponseBody =
ZioStreamHttpResponseBody(ZStream.empty, None) // TODO
): ZioResponseBody =
Left(ZioWebSockets.pipeToBody(pipe, o))

private def rawValueToEntity[R](bodyType: RawBodyType[R], r: R): ZioHttpResponseBody = {
private def rawValueToEntity[R](bodyType: RawBodyType[R], r: R): ZioHttpResponseBody =
bodyType match {
case RawBodyType.StringBody(charset) =>
val bytes = r.toString.getBytes(charset)
Expand Down Expand Up @@ -71,5 +76,4 @@ class ZioHttpToResponseBody extends ToResponseBody[ZioHttpResponseBody, ZioStrea
.getOrElse(ZioStreamHttpResponseBody(ZStream.fromPath(tapirFile.file.toPath), Some(tapirFile.file.length)))
case RawBodyType.MultipartBody(_, _) => throw new UnsupportedOperationException("Multipart is not supported")
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
package sttp.tapir.server.ziohttp
import sttp.capabilities.zio.ZioStreams
import sttp.capabilities.zio.ZioStreams.Pipe
import sttp.tapir.DecodeResult
import sttp.tapir.WebSocketBodyOutput
import sttp.tapir.model.WebSocketFrameDecodeFailure
import sttp.ws.{WebSocketFrame => SttpWebSocketFrame}
import zio.Chunk
import zio.Duration.fromScala
import zio.Schedule
import zio.ZIO
import zio.http.ChannelEvent.Read
import zio.http.WebSocketChannelEvent
import zio.http.{WebSocketFrame => ZioWebSocketFrame}
import zio.stream
import zio.stream.ZStream

import scala.concurrent.duration.FiniteDuration

object ZioWebSockets {

def pipeToBody[REQ, RESP](
pipe: Pipe[REQ, RESP],
o: WebSocketBodyOutput[Pipe[REQ, RESP], REQ, RESP, _, ZioStreams]
): WebSocketHandler = {
{ (in: stream.Stream[Throwable, WebSocketChannelEvent]) =>
{
for {
pongs <- zio.Queue.bounded[SttpWebSocketFrame](1)
sttpFrames = in.map(zWebSocketChannelEventToFrame).collectSome
concatenated = optionallyConcatenate(sttpFrames, o.concatenateFragmentedFrames)
ignoredPongs = optionallyIgnorePongs(concatenated, o.ignorePong)
autoPongs = optionallyAutoPongOnPing(ignoredPongs, pongs, o.autoPongOnPing)
autoPing = optionallyAutoPing(o.autoPing)
closeStream = stream.ZStream.from[SttpWebSocketFrame](SttpWebSocketFrame.close)
intermediateStream = autoPongs
.map {
case _: SttpWebSocketFrame.Close if !o.decodeCloseRequests => None
case f =>
o.requests.decode(f) match {
case failure: DecodeResult.Failure => throw new WebSocketFrameDecodeFailure(f, failure)
case DecodeResult.Value(v) => Some(v)
}
}
.collectWhileSome
.viaFunction(pipe)
.map(o.responses.encode)
.mergeHaltLeft(stream.ZStream.fromQueue[SttpWebSocketFrame](pongs, 1))
.mergeHaltLeft(autoPing) ++ closeStream
sendReceiveStream = intermediateStream.map(frameToZWebSocketChannelEvent)
} yield sendReceiveStream
}
}
}

private def zWebSocketChannelEventToFrame(channelEvent: WebSocketChannelEvent): Option[SttpWebSocketFrame] =
channelEvent match {
case Read(f @ ZioWebSocketFrame.Text(text)) => Some(SttpWebSocketFrame.Text(text, f.isFinal, rsv = None))
case Read(f @ ZioWebSocketFrame.Binary(buffer)) => Some(SttpWebSocketFrame.Binary(buffer.toArray, f.isFinal, rsv = None))
case Read(f @ ZioWebSocketFrame.Continuation(buffer)) => Some(SttpWebSocketFrame.Binary(buffer.toArray, f.isFinal, rsv = None))
case Read(ZioWebSocketFrame.Ping) => Some(SttpWebSocketFrame.ping)
case Read(ZioWebSocketFrame.Pong) => Some(SttpWebSocketFrame.pong)
case Read(ZioWebSocketFrame.Close(status, reason)) => Some(SttpWebSocketFrame.Close(status, reason.getOrElse("")))
case Read(f) => Some(SttpWebSocketFrame.Binary(Array.empty[Byte], f.isFinal, rsv = None))
case _ => None
}

private def frameToZWebSocketChannelEvent(f: SttpWebSocketFrame): WebSocketChannelEvent =
f match {
case SttpWebSocketFrame.Text(p, finalFragment, _) => Read(ZioWebSocketFrame.Text(p, finalFragment))
case SttpWebSocketFrame.Binary(p, finalFragment, _) => Read(ZioWebSocketFrame.Binary(Chunk.fromArray(p), finalFragment))
case SttpWebSocketFrame.Ping(_) => Read(ZioWebSocketFrame.Ping)
case SttpWebSocketFrame.Pong(_) => Read(ZioWebSocketFrame.Pong)
case SttpWebSocketFrame.Close(code, reason) => Read(ZioWebSocketFrame.Close(code, Some(reason)))
}

private def optionallyIgnorePongs(
sttpFrames: ZStream[Any, Throwable, SttpWebSocketFrame],
ignorePong: Boolean
): ZStream[Any, Throwable, SttpWebSocketFrame] = {
sttpFrames
.filter {
case _: SttpWebSocketFrame.Pong if ignorePong => false
case _ => true
}
}

private def optionallyAutoPing(
autoPing: Option[(FiniteDuration, SttpWebSocketFrame.Ping)]
): ZStream[Any, Nothing, SttpWebSocketFrame] = {
autoPing match {
case Some((duration, ping)) =>
stream.ZStream
.from(ping)
.repeat(Schedule.fixed(fromScala(duration)))
case None => stream.ZStream.empty
}
}

private def optionallyAutoPongOnPing(
sttpFrames: ZStream[Any, Throwable, SttpWebSocketFrame],
pongs: zio.Queue[SttpWebSocketFrame],
autoPongOnPing: Boolean
): ZStream[Any, Throwable, SttpWebSocketFrame] = {
if (autoPongOnPing) {
sttpFrames.mapZIO {
case SttpWebSocketFrame.Ping(payload) if autoPongOnPing =>
pongs.offer(SttpWebSocketFrame.Pong(payload)).as(Option.empty[SttpWebSocketFrame])
case f => ZIO.succeed(Some(f))
}.collectSome
} else sttpFrames
}

private def optionallyConcatenate(
sttpFrames: ZStream[Any, Throwable, SttpWebSocketFrame],
concatenate: Boolean
): ZStream[Any, Throwable, SttpWebSocketFrame] = {
if (concatenate) {
type Accumulator = Option[Either[Array[Byte], String]]

sttpFrames
.mapAccum(None: Accumulator) {
case (None, f: SttpWebSocketFrame.Ping) => (None, Some(f))
case (None, f: SttpWebSocketFrame.Pong) => (None, Some(f))
case (None, f: SttpWebSocketFrame.Close) => (None, Some(f))
case (None, f: SttpWebSocketFrame.Data[_]) if f.finalFragment => (None, Some(f))
case (Some(Left(acc)), f: SttpWebSocketFrame.Binary) if f.finalFragment => (None, Some(f.copy(payload = acc ++ f.payload)))
case (Some(Left(acc)), f: SttpWebSocketFrame.Binary) if !f.finalFragment => (Some(Left(acc ++ f.payload)), None)
case (Some(Right(acc)), f: SttpWebSocketFrame.Text) if f.finalFragment =>
println(s"final fragment: $f")
println(s"acc: $acc")
(None, Some(f.copy(payload = acc + f.payload)))
case (Some(Right(acc)), f: SttpWebSocketFrame.Text) if !f.finalFragment =>
println(s"final fragment: $f")
println(s"acc: $acc")
(Some(Right(acc + f.payload)), None)

case (acc, f) => throw new IllegalStateException(s"Cannot accumulate web socket frames. Accumulator: $acc, frame: $f.")
}
.collectSome
} else sttpFrames
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package sttp.tapir.server
import zio.http.WebSocketChannelEvent
import zio.{ZIO, stream}

package object ziohttp {
type WebSocketHandler =
stream.Stream[Throwable, WebSocketChannelEvent] => ZIO[Any, Throwable, stream.Stream[Throwable, WebSocketChannelEvent]]

type ZioResponseBody =
Either[WebSocketHandler, ZioHttpResponseBody]

}
Loading