Skip to content

Commit

Permalink
Using queues to receive and process websocket messages
Browse files Browse the repository at this point in the history
  • Loading branch information
yabosedira committed Sep 16, 2023
1 parent 806e208 commit 33c4b5e
Show file tree
Hide file tree
Showing 4 changed files with 154 additions and 135 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,9 @@ trait ZioHttpInterpreter[R] {
resp.body match {
case None => handleHttpResponse(resp, None)
case Some(Right(body)) => handleHttpResponse(resp, Some(body))
case Some(Left(body)) => handleWebSocketResponse(body)
case Some(Left(body)) =>
println(body)
handleWebSocketResponse(body)
}

case RequestResult.Failure(_) =>
Expand Down Expand Up @@ -85,8 +87,15 @@ trait ZioHttpInterpreter[R] {
}
}

private def handleWebSocketResponse(webSocketHandler: WebSocketHandler) = {
Handler.webSocket(webSocketHandler).toResponse
private def handleWebSocketResponse(webSocketHandler: WebSocketHandler): ZIO[Any, Nothing, Response] = {
Handler.webSocket { channel =>
for {
channelEventsQueue <- zio.Queue.unbounded[WebSocketChannelEvent]
messageReceptionFiber <- channel.receiveAll { message => channelEventsQueue.offer(message) }.fork
webSocketStream <- webSocketHandler(stream.ZStream.fromQueue(channelEventsQueue))
_ <- webSocketStream.mapZIO(channel.send).runDrain
} yield messageReceptionFiber.join
}.toResponse
}

private def handleHttpResponse(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,120 +4,129 @@ import sttp.capabilities.zio.ZioStreams.Pipe
import sttp.tapir.model.WebSocketFrameDecodeFailure
import sttp.tapir.{DecodeResult, WebSocketBodyOutput}
import sttp.ws.{WebSocketFrame => SttpWebSocketFrame}
import zio.Duration.fromScala
import zio.http.ChannelEvent.Read
import zio.http.{ChannelEvent, WebSocketChannel, WebSocketChannelEvent, WebSocketFrame => ZioWebSocketFrame}
import zio.http.{WebSocketChannel, WebSocketChannelEvent, WebSocketFrame => ZioWebSocketFrame}
import zio.stream.ZStream
import zio.{Chunk, Task, ZIO, stream}
import zio.{Chunk, Schedule, ZIO, stream}

object ZioWebSockets {
private val NormalClosureStatusCode = 1000
private val AbnormalClosureStatusCode = 1006

def pipeToBody[REQ, RESP](pipe: Pipe[REQ, RESP], o: WebSocketBodyOutput[Pipe[REQ, RESP], REQ, RESP, _, ZioStreams]): WebSocketHandler =
channel => {
val reqToRespPipeline = reqToResp(pipe, o)
import scala.concurrent.duration.FiniteDuration

channel.receiveAll {
case ChannelEvent.Read(ZioWebSocketFrame.Ping) if o.autoPongOnPing =>
channel.send(Read(ZioWebSocketFrame.Pong))
case ChannelEvent.Read(message) if message.isFinal =>
processWebSocketFrame(channel, reqToRespPipeline, message)
case ChannelEvent.Read(message) => // Fragmented message
for {
message <- accumulateFrames(channel, message)
response <- processWebSocketFrame(channel, reqToRespPipeline, message)
} yield response
case ChannelEvent.Unregistered =>
channel.send(Read(ZioWebSocketFrame.close(NormalClosureStatusCode)))
case _ =>
ZIO.unit
}
}
object ZioWebSockets {

private def reqToResp[REQ, RESP](
def pipeToBody[REQ, RESP](
pipe: Pipe[REQ, RESP],
o: WebSocketBodyOutput[Pipe[REQ, RESP], REQ, RESP, _, ZioStreams]
): stream.Stream[Throwable, ZioWebSocketFrame] => stream.Stream[Throwable, ZioWebSocketFrame] = {
in: stream.Stream[Throwable, ZioWebSocketFrame] =>
in
.map(zFrameToFrame)
.map {
case SttpWebSocketFrame.Close(_, _) if !o.decodeCloseRequests => None
case SttpWebSocketFrame.Pong(_) if o.ignorePong => None
case f: SttpWebSocketFrame =>
o.requests.decode(f) match {
case failure: DecodeResult.Failure => throw new WebSocketFrameDecodeFailure(f, failure)
case DecodeResult.Value(v) => Some(v)
): WebSocketHandler = {
{ (in: stream.Stream[Throwable, WebSocketChannelEvent]) =>
{
for {
pongs <- zio.Queue.bounded[SttpWebSocketFrame](1)
sttpFrames = in.map(zWebSocketChannelEventToFrame).collectSome
concatenated = optionallyConcatenate(sttpFrames, o.concatenateFragmentedFrames)
ignoredPongs = optionallyIgnorePongs(concatenated, o.ignorePong)
autoPongs = optionallyAutoPong(ignoredPongs, pongs, o.autoPongOnPing)
autoPing = optionallyAutoPing(o.autoPing)
closeStream = stream.ZStream.from[SttpWebSocketFrame](SttpWebSocketFrame.close)
interStream = autoPongs
.map {
case _: SttpWebSocketFrame.Close if !o.decodeCloseRequests => None
case f =>
o.requests.decode(f) match {
case failure: DecodeResult.Failure => throw new WebSocketFrameDecodeFailure(f, failure)
case DecodeResult.Value(v) => Some(v)
}
}
}
.collectWhileSome
.viaFunction(pipe)
.map(o.responses.encode)
.map(frameToZFrame)
.collectWhileSome
.viaFunction(pipe)
.map(o.responses.encode)
.mergeHaltLeft(stream.ZStream.fromQueue[SttpWebSocketFrame](pongs, 1))
.mergeHaltLeft(autoPing) ++ closeStream
r = interStream.map(frameToZWebSocketChannelEvent)
} yield r
}
}
}

private def processWebSocketFrame(
channel: WebSocketChannel,
body: stream.Stream[Throwable, ZioWebSocketFrame] => stream.Stream[Throwable, ZioWebSocketFrame],
message: ZioWebSocketFrame
) = {
ZStream
.from(message)
.viaFunction(body)
.mapZIO(wsf => channel.send(Read(wsf)))
.runDrain
private def zWebSocketChannelEventToFrame(channelEvent: WebSocketChannelEvent): Option[SttpWebSocketFrame] =
channelEvent match {
case Read(f @ ZioWebSocketFrame.Text(text)) => Some(SttpWebSocketFrame.Text(text, f.isFinal, rsv = None))
case Read(f @ ZioWebSocketFrame.Binary(buffer)) => Some(SttpWebSocketFrame.Binary(buffer.toArray, f.isFinal, rsv = None))
case Read(f @ ZioWebSocketFrame.Continuation(buffer)) => Some(SttpWebSocketFrame.Binary(buffer.toArray, f.isFinal, rsv = None))
case Read(ZioWebSocketFrame.Ping) => Some(SttpWebSocketFrame.ping)
case Read(ZioWebSocketFrame.Pong) => Some(SttpWebSocketFrame.pong)
case Read(ZioWebSocketFrame.Close(status, reason)) => Some(SttpWebSocketFrame.Close(status, reason.getOrElse("")))
case Read(f) => Some(SttpWebSocketFrame.Binary(Array.empty[Byte], f.isFinal, rsv = None))
case _ => None
}

private def frameToZWebSocketChannelEvent(f: SttpWebSocketFrame): WebSocketChannelEvent =
f match {
case SttpWebSocketFrame.Text(p, finalFragment, _) => Read(ZioWebSocketFrame.Text(p, finalFragment))
case SttpWebSocketFrame.Binary(p, finalFragment, _) => Read(ZioWebSocketFrame.Binary(Chunk.fromArray(p), finalFragment))
case SttpWebSocketFrame.Ping(_) => Read(ZioWebSocketFrame.Ping)
case SttpWebSocketFrame.Pong(_) => Read(ZioWebSocketFrame.Pong)
case SttpWebSocketFrame.Close(code, reason) => Read(ZioWebSocketFrame.Close(code, Some(reason)))
}

private def optionallyIgnorePongs(
sttpFrames: ZStream[Any, Throwable, SttpWebSocketFrame],
ignorePong: Boolean
): ZStream[Any, Throwable, SttpWebSocketFrame] = {
sttpFrames
.filter {
case _: SttpWebSocketFrame.Pong if ignorePong => false
case _ => true
}
}

private def accumulateFrames(channel: WebSocketChannel, webSocketFrame: ZioWebSocketFrame): Task[ZioWebSocketFrame] = {
ZIO.iterate(webSocketFrame)(!_.isFinal) { wsf =>
for {
channelEvent <- channel.receive
accumulatedWebSocketFrame <- handleChannelEvent(channel, channelEvent, wsf)
} yield accumulatedWebSocketFrame
private def optionallyAutoPing(
autoPing: Option[(FiniteDuration, SttpWebSocketFrame.Ping)]
): ZStream[Any, Nothing, SttpWebSocketFrame] = {
autoPing match {
case Some((duration, ping)) =>
stream.ZStream
.from(ping)
.repeat(Schedule.fixed(fromScala(duration)))
case None => stream.ZStream.empty
}
}

private def handleChannelEvent(
channel: WebSocketChannel,
channelEvent: WebSocketChannelEvent,
acc: ZioWebSocketFrame
): Task[ZioWebSocketFrame] = {
channelEvent match {
case ChannelEvent.ExceptionCaught(cause) =>
channel.send(Read(ZioWebSocketFrame.close(AbnormalClosureStatusCode, Some(cause.getMessage)))) *>
channel.shutdown.map(_ => acc)
case Read(ZioWebSocketFrame.Continuation(newBuffer)) =>
acc match {
case b @ ZioWebSocketFrame.Binary(bytes) => ZIO.succeed(b.copy(bytes ++ newBuffer))
case t @ ZioWebSocketFrame.Text(text) => ZIO.succeed(t.copy(text + new String(newBuffer.toArray)))
case ZioWebSocketFrame.Close(status, reason) =>
ZIO.fail(new RuntimeException(s"Received unexpected close frame: $status, $reason"))
case ZioWebSocketFrame.Continuation(buffer) =>
channel.send(Read(ZioWebSocketFrame.Continuation(buffer))).map(_ => acc)
case ZioWebSocketFrame.Ping => channel.send(Read(ZioWebSocketFrame.Pong)).map(_ => acc)
case ZioWebSocketFrame.Pong => channel.send(Read(ZioWebSocketFrame.Ping)).map(_ => acc)
case _ => ZIO.succeed(acc)
private def optionallyAutoPong(
sttpFrames: ZStream[Any, Throwable, SttpWebSocketFrame],
pongs: zio.Queue[SttpWebSocketFrame],
autoPongOnPing: Boolean
): ZStream[Any, Throwable, SttpWebSocketFrame] = {
if (autoPongOnPing) {
sttpFrames.mapZIO {
case _: SttpWebSocketFrame.Ping if autoPongOnPing =>
pongs.offer(SttpWebSocketFrame.pong).as(Option.empty[SttpWebSocketFrame])
case f => ZIO.succeed(Some(f))
}.collectSome
} else sttpFrames
}

private def optionallyConcatenate(
sttpFrames: ZStream[Any, Throwable, SttpWebSocketFrame],
concatenate: Boolean
): ZStream[Any, Throwable, SttpWebSocketFrame] = {
if (concatenate) {
type Accumulator = Option[Either[Array[Byte], String]]

sttpFrames
.mapAccum(None: Accumulator) {
case (None, f: SttpWebSocketFrame.Ping) => (None, Some(f))
case (None, f: SttpWebSocketFrame.Pong) => (None, Some(f))
case (None, f: SttpWebSocketFrame.Close) => (None, Some(f))
case (None, f: SttpWebSocketFrame.Data[_]) if f.finalFragment => (None, Some(f))
case (Some(Left(acc)), f: SttpWebSocketFrame.Binary) if f.finalFragment => (None, Some(f.copy(payload = acc ++ f.payload)))
case (Some(Left(acc)), f: SttpWebSocketFrame.Binary) if !f.finalFragment => (Some(Left(acc ++ f.payload)), None)
case (Some(Right(acc)), f: SttpWebSocketFrame.Text) if f.finalFragment => (None, Some(f.copy(payload = acc + f.payload)))
case (Some(Right(acc)), f: SttpWebSocketFrame.Text) if !f.finalFragment => (Some(Right(acc + f.payload)), None)
case (acc, f) => throw new IllegalStateException(s"Cannot accumulate web socket frames. Accumulator: $acc, frame: $f.")
}
case _ => ZIO.succeed(acc)
}
.collectSome
} else sttpFrames
}
private def zFrameToFrame(f: ZioWebSocketFrame): SttpWebSocketFrame =
f match {
case ZioWebSocketFrame.Text(text) => SttpWebSocketFrame.Text(text, f.isFinal, rsv = None)
case ZioWebSocketFrame.Binary(buffer) => SttpWebSocketFrame.Binary(buffer.toArray, f.isFinal, rsv = None)
case ZioWebSocketFrame.Continuation(buffer) => SttpWebSocketFrame.Binary(buffer.toArray, f.isFinal, rsv = None)
case ZioWebSocketFrame.Ping => SttpWebSocketFrame.ping
case ZioWebSocketFrame.Pong => SttpWebSocketFrame.pong
case ZioWebSocketFrame.Close(status, reason) => SttpWebSocketFrame.Close(status, reason.getOrElse(""))
case _ => SttpWebSocketFrame.Binary(Array.empty[Byte], f.isFinal, rsv = None)
}

private def frameToZFrame(f: SttpWebSocketFrame): ZioWebSocketFrame =
f match {
case SttpWebSocketFrame.Text(p, finalFragment, _) => ZioWebSocketFrame.Text(p, finalFragment)
case SttpWebSocketFrame.Binary(p, finalFragment, _) => ZioWebSocketFrame.Binary(Chunk.fromArray(p), finalFragment)
case SttpWebSocketFrame.Ping(_) => ZioWebSocketFrame.Ping
case SttpWebSocketFrame.Pong(_) => ZioWebSocketFrame.Pong
case SttpWebSocketFrame.Close(code, reason) => ZioWebSocketFrame.Close(code, Some(reason))
}
}
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
package sttp.tapir.server
import zio.Task
import zio.http.WebSocketChannel
import zio.http.WebSocketChannelEvent
import zio.{ZIO, stream}

package object ziohttp {
type WebSocketHandler = WebSocketChannel => Task[Unit]
type WebSocketHandler =
stream.Stream[Throwable, WebSocketChannelEvent] => ZIO[Any, Throwable, stream.Stream[Throwable, WebSocketChannelEvent]]

type ZioResponseBody =
Either[WebSocketHandler, ZioHttpResponseBody]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,23 @@ package sttp.tapir.server.ziohttp

import cats.effect.{IO, Resource}
import io.netty.channel.{ChannelFactory, EventLoopGroup, ServerChannel}
import org.scalatest.{Assertion, Exceptional, FutureOutcome}
import org.scalatest.matchers.should.Matchers._
import org.scalatest.{Assertion, Exceptional, FutureOutcome}
import sttp.capabilities.zio.ZioStreams
import sttp.client3._
import sttp.client3.testing.SttpBackendStub
import sttp.model.MediaType
import sttp.monad.MonadError
import sttp.tapir.{PublicEndpoint, _}
import sttp.tapir.server.stub.TapirStubInterpreter
import sttp.tapir.server.tests._
import sttp.tapir.tests.{Test, TestSuite}
import sttp.tapir.ztapir.{RIOMonadError, RichZEndpoint}
import zio.{Promise, Ref, Runtime, Task, UIO, Unsafe, ZEnvironment, ZIO, ZLayer}
import zio.http.{HttpAppMiddleware, Path, Request, URL}
import sttp.tapir.{PublicEndpoint, _}
import zio.http.netty.{ChannelFactories, ChannelType, EventLoopGroups}
import zio.http.{HttpAppMiddleware, Path, Request, URL}
import zio.interop.catz._
import zio.stream.{ZPipeline, ZStream}
import zio.{Promise, Ref, Runtime, Task, UIO, Unsafe, ZEnvironment, ZIO, ZLayer}

import java.nio.charset.Charset
import java.time
Expand Down Expand Up @@ -173,33 +173,33 @@ class ZioHttpServerTest extends TestSuite {

implicit val m: MonadError[Task] = new RIOMonadError[Any]

new ServerBasicTests(
createServerTest,
interpreter,
multipleValueHeaderSupport = false,
supportsUrlEncodedPathSegments = false,
supportsMultipleSetCookieHeaders = false,
invulnerableToUnsanitizedHeaders = false
).tests() ++
// TODO: re-enable static content once a newer zio http is available. Currently these tests often fail with:
// Cause: java.io.IOException: parsing HTTP/1.1 status line, receiving [f2 content], parser state [STATUS_LINE]
new AllServerTests(
createServerTest,
interpreter,
backend,
basic = false,
staticContent = false,
multipart = false,
file = false,
options = false
).tests() ++
new ServerStreamingTests(createServerTest, ZioStreams).tests() ++
new ZioHttpCompositionTest(createServerTest).tests() ++
new ServerWebSocketTests(createServerTest, ZioStreams) {
override def functionToPipe[A, B](f: A => B): ZioStreams.Pipe[A, B] = in => in.map(f)
override def emptyPipe[A, B]: ZioStreams.Pipe[A, B] = _ => ZStream.empty
}.tests() ++
additionalTests()
// new ServerBasicTests(
// createServerTest,
// interpreter,
// multipleValueHeaderSupport = false,
// supportsUrlEncodedPathSegments = false,
// supportsMultipleSetCookieHeaders = false,
// invulnerableToUnsanitizedHeaders = false
// ).tests() ++
// TODO: re-enable static content once a newer zio http is available. Currently these tests often fail with:
// Cause: java.io.IOException: parsing HTTP/1.1 status line, receiving [f2 content], parser state [STATUS_LINE]
// new AllServerTests(
// createServerTest,
// interpreter,
// backend,
// basic = false,
// staticContent = false,
// multipart = false,
// file = false,
// options = false
// ).tests() ++
// new ServerStreamingTests(createServerTest, ZioStreams).tests() ++
// new ZioHttpCompositionTest(createServerTest).tests() ++
new ServerWebSocketTests(createServerTest, ZioStreams) {
override def functionToPipe[A, B](f: A => B): ZioStreams.Pipe[A, B] = in => in.map(f)
override def emptyPipe[A, B]: ZioStreams.Pipe[A, B] = _ => ZStream.empty
}.tests()
// additionalTests()
}
}
}

0 comments on commit 33c4b5e

Please sign in to comment.