Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updated Http4sWebSockets. #3393

Merged
merged 5 commits into from
Dec 21, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,114 +1,140 @@
package sttp.tapir.server.http4s

import cats.Monad
import cats.effect.Temporal
import cats.effect.std.Queue
import cats.{Applicative, Monad}
import cats.syntax.all._
import fs2._
import fs2.concurrent.Channel
import org.http4s.websocket.{WebSocketFrame => Http4sWebSocketFrame}
import scodec.bits.ByteVector
import sttp.capabilities.fs2.Fs2Streams
import sttp.tapir.model.WebSocketFrameDecodeFailure
import sttp.tapir.{DecodeResult, WebSocketBodyOutput}
import sttp.ws.WebSocketFrame
import cats.effect.implicits._

private[http4s] object Http4sWebSockets {
def pipeToBody[F[_]: Temporal, REQ, RESP](
pipe: Pipe[F, REQ, RESP],
o: WebSocketBodyOutput[Pipe[F, REQ, RESP], REQ, RESP, _, Fs2Streams[F]]
): F[Pipe[F, Http4sWebSocketFrame, Http4sWebSocketFrame]] = {
Queue.bounded[F, WebSocketFrame](2).map { pongs => (in: Stream[F, Http4sWebSocketFrame]) =>
val sttpFrames = in.map(http4sFrameToFrame)
val concatenated = optionallyConcatenateFrames(sttpFrames, o.concatenateFragmentedFrames)
val ignorePongs = optionallyIgnorePong(concatenated, o.ignorePong)
val autoPongs = optionallyAutoPong(ignorePongs, pongs, o.autoPongOnPing)
val autoPings = o.autoPing match {
case Some((interval, frame)) => Stream.awakeEvery[F](interval).map(_ => frame)
case None => Stream.empty
}
val decodeClose = optionallyDecodeClose(autoPongs, o.decodeCloseRequests)
pipe: Pipe[F, REQ, RESP],
o: WebSocketBodyOutput[Pipe[F, REQ, RESP], REQ, RESP, _, Fs2Streams[F]]
): F[Pipe[F, Http4sWebSocketFrame, Http4sWebSocketFrame]] = {
if ((!o.autoPongOnPing) && o.autoPing.isEmpty) {
// fast track: lift Http4sWebSocketFrames into REQ, run through pipe, convert RESP back to Http4sWebSocketFrame

(decodeClose
.map { f =>
o.requests.decode(f) match {
case x: DecodeResult.Value[REQ] => x.v
case failure: DecodeResult.Failure => throw new WebSocketFrameDecodeFailure(f, failure)
(in: Stream[F, Http4sWebSocketFrame]) =>
val decodeClose = optionallyDecodeClose(in, o.decodeCloseRequests)
val sttpFrames = decodeClose.map(http4sFrameToFrame)
val concatenated = optionallyConcatenateFrames(sttpFrames, o.concatenateFragmentedFrames)
val ignorePongs = optionallyIgnorePong(concatenated, o.ignorePong)
ignorePongs
.map { f =>
o.requests.decode(f) match {
case x: DecodeResult.Value[REQ] => x.v
case failure: DecodeResult.Failure => throw new WebSocketFrameDecodeFailure(f, failure)
}
}
.through(pipe)
.mapChunks(_.map(r => frameToHttp4sFrame(o.responses.encode(r))))
.append(Stream(frameToHttp4sFrame(WebSocketFrame.close)))
}.pure[F] else {
// concurrently merge business logic response, autoPings, autoPongOnPing
// use fs2.Channel to perform the merge (more efficient than Stream#mergeHaltL / Stream#parJoin)

Channel.bounded[F, Chunk[Http4sWebSocketFrame]](64).map { c => (in: Stream[F, Http4sWebSocketFrame]) =>
val decodeClose = optionallyDecodeClose(in, o.decodeCloseRequests)
val sttpFrames = decodeClose.map(http4sFrameToFrame)
val concatenated = optionallyConcatenateFrames(sttpFrames, o.concatenateFragmentedFrames)
val ignorePongs = optionallyIgnorePong(concatenated, o.ignorePong)
val autoPongs = optionallyAutoPong(ignorePongs, c, o.autoPongOnPing)
val autoPings = o.autoPing match {
case Some((interval, frame)) => (c.send(Chunk.singleton(frameToHttp4sFrame(frame))) >> Temporal[F].sleep(interval)).foreverM[Unit]
case None => Applicative[F].unit
}
.through(pipe)
.map(o.responses.encode)
.mergeHaltL(Stream.repeatEval(pongs.take))
.mergeHaltL(autoPings) ++ Stream(WebSocketFrame.close))
.map(frameToHttp4sFrame)

val outputProducer = autoPongs
.map { f =>
o.requests.decode(f) match {
case x: DecodeResult.Value[REQ] => x.v
case failure: DecodeResult.Failure => throw new WebSocketFrameDecodeFailure(f, failure)
}
}
.through(pipe)
.chunks
.foreach(chunk => c.send(chunk.map(r => frameToHttp4sFrame(o.responses.encode(r)))).void)
.compile
.drain

val outcomes = (outputProducer.guarantee(c.close.void), autoPings).parTupled.void

Stream
.bracket(Temporal[F].start(outcomes))(f => f.cancel >> f.joinWithUnit) >>
c.stream.append(Stream(Chunk.singleton(frameToHttp4sFrame(WebSocketFrame.close)))).unchunks
}
}
}

private def http4sFrameToFrame(f: Http4sWebSocketFrame): WebSocketFrame =
f match {
case t: Http4sWebSocketFrame.Text => WebSocketFrame.Text(t.str, t.last, None)
case x: Http4sWebSocketFrame.Ping => WebSocketFrame.Ping(x.data.toArray)
case x: Http4sWebSocketFrame.Pong => WebSocketFrame.Pong(x.data.toArray)
case t: Http4sWebSocketFrame.Text => WebSocketFrame.Text(t.str, t.last, None)
case x: Http4sWebSocketFrame.Ping => WebSocketFrame.Ping(x.data.toArray)
case x: Http4sWebSocketFrame.Pong => WebSocketFrame.Pong(x.data.toArray)
case c: Http4sWebSocketFrame.Close => WebSocketFrame.Close(c.closeCode, "")
case _ => WebSocketFrame.Binary(f.data.toArray, f.last, None)
case _ => WebSocketFrame.Binary(f.data.toArray, f.last, None)
}

private def frameToHttp4sFrame(w: WebSocketFrame): Http4sWebSocketFrame =
w match {
case x: WebSocketFrame.Text => Http4sWebSocketFrame.Text(x.payload, x.finalFragment)
case x: WebSocketFrame.Text => Http4sWebSocketFrame.Text(x.payload, x.finalFragment)
case x: WebSocketFrame.Binary => Http4sWebSocketFrame.Binary(ByteVector(x.payload), x.finalFragment)
case x: WebSocketFrame.Ping => Http4sWebSocketFrame.Ping(ByteVector(x.payload))
case x: WebSocketFrame.Pong => Http4sWebSocketFrame.Pong(ByteVector(x.payload))
case x: WebSocketFrame.Close => Http4sWebSocketFrame.Close(x.statusCode, x.reasonText).fold(throw _, identity)
case x: WebSocketFrame.Ping => Http4sWebSocketFrame.Ping(ByteVector(x.payload))
case x: WebSocketFrame.Pong => Http4sWebSocketFrame.Pong(ByteVector(x.payload))
case x: WebSocketFrame.Close => Http4sWebSocketFrame.Close(x.statusCode, x.reasonText).fold(throw _, identity)
}

private def optionallyConcatenateFrames[F[_]](s: Stream[F, WebSocketFrame], doConcatenate: Boolean): Stream[F, WebSocketFrame] = {
private def optionallyConcatenateFrames[F[_]](s: Stream[F, WebSocketFrame], doConcatenate: Boolean): Stream[F, WebSocketFrame] =
if (doConcatenate) {
type Accumulator = Option[Either[Array[Byte], String]]

s.mapAccumulate(None: Accumulator) {
case (None, f: WebSocketFrame.Ping) => (None, Some(f))
case (None, f: WebSocketFrame.Pong) => (None, Some(f))
case (None, f: WebSocketFrame.Close) => (None, Some(f))
case (None, f: WebSocketFrame.Data[_]) if f.finalFragment => (None, Some(f))
case (Some(Left(acc)), f: WebSocketFrame.Binary) if f.finalFragment => (None, Some(f.copy(payload = acc ++ f.payload)))
case (None, f: WebSocketFrame.Ping) => (None, Some(f))
case (None, f: WebSocketFrame.Pong) => (None, Some(f))
case (None, f: WebSocketFrame.Close) => (None, Some(f))
case (None, f: WebSocketFrame.Data[_]) if f.finalFragment => (None, Some(f))
case (Some(Left(acc)), f: WebSocketFrame.Binary) if f.finalFragment => (None, Some(f.copy(payload = acc ++ f.payload)))
case (Some(Left(acc)), f: WebSocketFrame.Binary) if !f.finalFragment => (Some(Left(acc ++ f.payload)), None)
case (Some(Right(acc)), f: WebSocketFrame.Text) if f.finalFragment => (None, Some(f.copy(payload = acc + f.payload)))
case (Some(Right(acc)), f: WebSocketFrame.Text) if !f.finalFragment => (Some(Right(acc + f.payload)), None)
case (Some(Right(acc)), f: WebSocketFrame.Text) if f.finalFragment => (None, Some(f.copy(payload = acc + f.payload)))
case (Some(Right(acc)), f: WebSocketFrame.Text) if !f.finalFragment => (Some(Right(acc + f.payload)), None)
case (acc, f) => throw new IllegalStateException(s"Cannot accumulate web socket frames. Accumulator: $acc, frame: $f.")
}.collect { case (_, Some(f)) => f }
} else {
s
}
}
} else s

private def optionallyIgnorePong[F[_]](s: Stream[F, WebSocketFrame], doIgnore: Boolean): Stream[F, WebSocketFrame] = {
if (doIgnore) {
s.filter {
case _: WebSocketFrame.Pong => false
case _ => true
case _ => true
}
} else s
}

private def optionallyAutoPong[F[_]: Monad](
s: Stream[F, WebSocketFrame],
pongs: Queue[F, WebSocketFrame],
doAuto: Boolean
): Stream[F, WebSocketFrame] = {
private def optionallyAutoPong[F[_] : Monad](
s: Stream[F, WebSocketFrame],
c: Channel[F, Chunk[Http4sWebSocketFrame]],
doAuto: Boolean
): Stream[F, WebSocketFrame] =
if (doAuto) {
val trueF = true.pure[F]
s.evalFilter {
case ping: WebSocketFrame.Ping => pongs.offer(WebSocketFrame.Pong(ping.payload)).map(_ => false)
case _ => trueF
case ping: WebSocketFrame.Ping => c.send(Chunk.singleton(frameToHttp4sFrame(WebSocketFrame.Pong(ping.payload)))).map(_ => false)
case _ => trueF
}
} else s
}

private def optionallyDecodeClose[F[_]](s: Stream[F, WebSocketFrame], doDecodeClose: Boolean): Stream[F, WebSocketFrame] =
private def optionallyDecodeClose[F[_]](s: Stream[F, Http4sWebSocketFrame], doDecodeClose: Boolean): Stream[F, Http4sWebSocketFrame] =
if (!doDecodeClose) {
s.takeWhile {
case _: WebSocketFrame.Close => false
case _ => true
case _: Http4sWebSocketFrame.Close => false
case _ => true
}
} else s
}