Skip to content

Commit

Permalink
Updated Vertx Cats WebSocket (softwaremill#3573)
Browse files Browse the repository at this point in the history
  • Loading branch information
DybekK authored Mar 12, 2024
1 parent 87cb412 commit bd97fe2
Showing 1 changed file with 76 additions and 33 deletions.
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
package sttp.tapir.server.vertx.cats.streams

import _root_.fs2.{Chunk, Stream}
import cats.effect.kernel.Async
import _root_.fs2.concurrent.Channel
import cats.effect.kernel.Resource.ExitCase.{Canceled, Errored, Succeeded}
import cats.effect.{Deferred, Ref}
import cats.effect.{Deferred, Ref, Sync, Async, GenSpawn}
import cats.syntax.all._
import cats.effect.implicits._
import io.vertx.core.Handler
import io.vertx.core.buffer.Buffer
import io.vertx.core.streams.ReadStream
Expand All @@ -29,7 +30,9 @@ object fs2 {
dfd.get
}

implicit def fs2ReadStreamCompatible[F[_]](opts: VertxCatsServerOptions[F])(implicit F: Async[F]): ReadStreamCompatible[Fs2Streams[F]] = {
implicit def fs2ReadStreamCompatible[F[_]: Async](
opts: VertxCatsServerOptions[F]
): ReadStreamCompatible[Fs2Streams[F]] = {
new ReadStreamCompatible[Fs2Streams[F]] {
override val streams: Fs2Streams[F] = Fs2Streams[F]

Expand All @@ -41,13 +44,13 @@ object fs2 {
for {
promise <- Deferred[F, Unit]
state <- Ref.of(StreamState.empty[F, O](promise))
_ <- F.start(
_ <- GenSpawn[F].start(
stream
.evalMap({ chunk =>
val buffer = fn(chunk)
state.get.flatMap {
case StreamState(None, handler, _, _) =>
F.delay(handler.handle(buffer))
Sync[F].delay(handler.handle(buffer))
case StreamState(Some(promise), _, _, _) =>
for {
_ <- promise.get
Expand All @@ -60,15 +63,15 @@ object fs2 {
.onFinalizeCase({
case Succeeded =>
state.get.flatMap { state =>
F.delay(state.endHandler.handle(null))
Sync[F].delay(state.endHandler.handle(null))
}
case Canceled =>
state.get.flatMap { state =>
F.delay(state.errorHandler.handle(new Exception("Cancelled!")))
Sync[F].delay(state.errorHandler.handle(new Exception("Cancelled!")))
}
case Errored(cause) =>
state.get.flatMap { state =>
F.delay(state.errorHandler.handle(cause))
Sync[F].delay(state.errorHandler.handle(cause))
}
})
.compile
Expand Down Expand Up @@ -136,7 +139,7 @@ object fs2 {
} yield result
}

_ <- F.start(
_ <- GenSpawn[F].start(
Stream
.unfoldEval[F, Unit, ActivationEvent](())({ _ =>
for {
Expand All @@ -151,8 +154,8 @@ object fs2 {
} yield result.map((_, ()))
})
.evalMap({
case Pause => F.delay(readStream.pause())
case Resume => F.delay(readStream.resume())
case Pause => Sync[F].delay(readStream.pause())
case Resume => Sync[F].delay(readStream.resume())
})
.compile
.drain
Expand All @@ -173,36 +176,76 @@ object fs2 {
}
}

override def webSocketPipe[REQ, RESP](
readStream: ReadStream[WebSocketFrame],
private def decodeFrame[REQ, RESP](
frame: WebSocketFrame,
o: WebSocketBodyOutput[streams.Pipe[REQ, RESP], REQ, RESP, _, Fs2Streams[F]]
): REQ = {
o.requests.decode(frame) match {
case DecodeResult.Value(v) => v
case failure: DecodeResult.Failure => throw new WebSocketFrameDecodeFailure(frame, failure)
}
}

private def fastPath[REQ, RESP](
in: Stream[F, WebSocketFrame],
pipe: streams.Pipe[REQ, RESP],
o: WebSocketBodyOutput[streams.Pipe[REQ, RESP], REQ, RESP, _, Fs2Streams[F]]
): ReadStream[WebSocketFrame] = {
val stream0 = fromReadStreamInternal(readStream)
val stream1 = optionallyContatenateFrames(stream0, o.concatenateFragmentedFrames)
val stream2 = optionallyIgnorePong(stream1, o.ignorePong)
val autoPings = o.autoPing match {
case Some((interval, frame)) =>
Stream.awakeEvery(interval).as(frame)
case None =>
Stream.empty
}
val contatenatedFrames = optionallyContatenateFrames(in, o.concatenateFragmentedFrames)
val ignorePongs = optionallyIgnorePong(contatenatedFrames, o.ignorePong)

val stream3 = stream2
.map { frame =>
o.requests.decode(frame) match {
case DecodeResult.Value(v) =>
v
case failure: DecodeResult.Failure =>
throw new WebSocketFrameDecodeFailure(frame, failure)
}
}
val stream = ignorePongs
.map(decodeFrame(_, o))
.through(pipe)
.map(o.responses.encode)
.mergeHaltL(autoPings)
.append(Stream(WebSocketFrame.close))

mapToReadStream[WebSocketFrame, WebSocketFrame](stream3, identity)
mapToReadStream[WebSocketFrame, WebSocketFrame](stream, identity)
}

private def standardPath[REQ, RESP](
in: Stream[F, WebSocketFrame],
pipe: streams.Pipe[REQ, RESP],
o: WebSocketBodyOutput[streams.Pipe[REQ, RESP], REQ, RESP, _, Fs2Streams[F]]
): ReadStream[WebSocketFrame] = {
val mergedStream = Channel.bounded[F, Chunk[WebSocketFrame]](64).map { c =>
val contatenatedFrames = optionallyContatenateFrames(in, o.concatenateFragmentedFrames)
val ignoredPongs = optionallyIgnorePong(contatenatedFrames, o.ignorePong)

val autoPings = o.autoPing match {
case Some((interval, frame)) => Stream.awakeEvery(interval).as(frame)
case None => Stream.empty
}

val outputProducer = ignoredPongs
.map(decodeFrame(_, o))
.through(pipe)
.chunks
.foreach(chunk => c.send(chunk.map(r => o.responses.encode(r))).void)
.compile
.drain

val outcomes = (outputProducer.guarantee(c.close.void), autoPings.compile.drain).parTupled.void

Stream.bracket(outcomes.start)(f => f.cancel >> f.joinWithUnit) >>
c.stream.append(Stream(Chunk.singleton(WebSocketFrame.close))).unchunks
}

mapToReadStream[WebSocketFrame, WebSocketFrame](Stream.force(mergedStream), identity)
}

override def webSocketPipe[REQ, RESP](
readStream: ReadStream[WebSocketFrame],
pipe: streams.Pipe[REQ, RESP],
o: WebSocketBodyOutput[streams.Pipe[REQ, RESP], REQ, RESP, _, Fs2Streams[F]]
): ReadStream[WebSocketFrame] = {
val stream = fromReadStreamInternal(readStream)

if ((!o.autoPongOnPing) && o.autoPing.isEmpty) {
fastPath(stream, pipe, o)
} else {
standardPath(stream, pipe, o)
}
}

def optionallyContatenateFrames(s: Stream[F, WebSocketFrame], doConcatenate: Boolean): Stream[F, WebSocketFrame] =
Expand Down

0 comments on commit bd97fe2

Please sign in to comment.