Skip to content

Commit

Permalink
Merge pull request #3340 from kamilkloch/http4s-websockets
Browse files Browse the repository at this point in the history
Improve Http4sWebSockets.pipeToBody
  • Loading branch information
adamw authored Nov 24, 2023
2 parents fdfa3da + 3859797 commit d54927e
Showing 1 changed file with 31 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ private[http4s] object Http4sWebSockets {
pipe: Pipe[F, REQ, RESP],
o: WebSocketBodyOutput[Pipe[F, REQ, RESP], REQ, RESP, _, Fs2Streams[F]]
): F[Pipe[F, Http4sWebSocketFrame, Http4sWebSocketFrame]] = {
Queue.bounded[F, WebSocketFrame](1).map { pongs => (in: Stream[F, Http4sWebSocketFrame]) =>
Queue.bounded[F, WebSocketFrame](2).map { pongs => (in: Stream[F, Http4sWebSocketFrame]) =>
val sttpFrames = in.map(http4sFrameToFrame)
val concatenated = optionallyConcatenateFrames(sttpFrames, o.concatenateFragmentedFrames)
val ignorePongs = optionallyIgnorePong(concatenated, o.ignorePong)
Expand All @@ -26,17 +26,15 @@ private[http4s] object Http4sWebSockets {
case Some((interval, frame)) => Stream.awakeEvery[F](interval).map(_ => frame)
case None => Stream.empty
}
val decodeClose = optionallyDecodeClose(autoPongs, o.decodeCloseRequests)

(autoPongs
.map {
case _: WebSocketFrame.Close if !o.decodeCloseRequests => None
case f =>
o.requests.decode(f) match {
case failure: DecodeResult.Failure => throw new WebSocketFrameDecodeFailure(f, failure)
case DecodeResult.Value(v) => Some(v)
}
(decodeClose
.map { f =>
o.requests.decode(f) match {
case x: DecodeResult.Value[REQ] => x.v
case failure: DecodeResult.Failure => throw new WebSocketFrameDecodeFailure(f, failure)
}
}
.unNoneTerminate
.through(pipe)
.map(o.responses.encode)
.mergeHaltL(Stream.repeatEval(pongs.take))
Expand All @@ -47,22 +45,21 @@ private[http4s] object Http4sWebSockets {

private def http4sFrameToFrame(f: Http4sWebSocketFrame): WebSocketFrame =
f match {
case t: Http4sWebSocketFrame.Text => WebSocketFrame.Text(t.str, t.last, None)
case Http4sWebSocketFrame.Ping(data) => WebSocketFrame.Ping(data.toArray)
case Http4sWebSocketFrame.Pong(data) => WebSocketFrame.Pong(data.toArray)
case c: Http4sWebSocketFrame.Close => WebSocketFrame.Close(c.closeCode, "")
case _ => WebSocketFrame.Binary(f.data.toArray, f.last, None)
case t: Http4sWebSocketFrame.Text => WebSocketFrame.Text(t.str, t.last, None)
case x: Http4sWebSocketFrame.Ping => WebSocketFrame.Ping(x.data.toArray)
case x: Http4sWebSocketFrame.Pong => WebSocketFrame.Pong(x.data.toArray)
case c: Http4sWebSocketFrame.Close => WebSocketFrame.Close(c.closeCode, "")
case _ => WebSocketFrame.Binary(f.data.toArray, f.last, None)
}

private def frameToHttp4sFrame(w: WebSocketFrame): Http4sWebSocketFrame = {
private def frameToHttp4sFrame(w: WebSocketFrame): Http4sWebSocketFrame =
w match {
case WebSocketFrame.Text(p, finalFragment, _) => Http4sWebSocketFrame.Text(p, finalFragment)
case WebSocketFrame.Binary(p, finalFragment, _) => Http4sWebSocketFrame.Binary(ByteVector(p), finalFragment)
case WebSocketFrame.Ping(p) => Http4sWebSocketFrame.Ping(ByteVector(p))
case WebSocketFrame.Pong(p) => Http4sWebSocketFrame.Pong(ByteVector(p))
case WebSocketFrame.Close(code, reason) => Http4sWebSocketFrame.Close(code, reason).fold(throw _, identity)
case x: WebSocketFrame.Text => Http4sWebSocketFrame.Text(x.payload, x.finalFragment)
case x: WebSocketFrame.Binary => Http4sWebSocketFrame.Binary(ByteVector(x.payload), x.finalFragment)
case x: WebSocketFrame.Ping => Http4sWebSocketFrame.Ping(ByteVector(x.payload))
case x: WebSocketFrame.Pong => Http4sWebSocketFrame.Pong(ByteVector(x.payload))
case x: WebSocketFrame.Close => Http4sWebSocketFrame.Close(x.statusCode, x.reasonText).fold(throw _, identity)
}
}

private def optionallyConcatenateFrames[F[_]](s: Stream[F, WebSocketFrame], doConcatenate: Boolean): Stream[F, WebSocketFrame] = {
if (doConcatenate) {
Expand All @@ -87,7 +84,7 @@ private[http4s] object Http4sWebSockets {
private def optionallyIgnorePong[F[_]](s: Stream[F, WebSocketFrame], doIgnore: Boolean): Stream[F, WebSocketFrame] = {
if (doIgnore) {
s.filter {
case WebSocketFrame.Pong(_) => false
case _: WebSocketFrame.Pong => false
case _ => true
}
} else s
Expand All @@ -99,12 +96,18 @@ private[http4s] object Http4sWebSockets {
doAuto: Boolean
): Stream[F, WebSocketFrame] = {
if (doAuto) {
s.evalMap {
case WebSocketFrame.Ping(payload) => pongs.offer(WebSocketFrame.Pong(payload)).map(_ => none[WebSocketFrame])
case f => f.some.pure[F]
}.collect { case Some(f) =>
f
s.evalMapFilter {
case ping: WebSocketFrame.Ping => pongs.offer(WebSocketFrame.Pong(ping.payload)).as[Option[WebSocketFrame]](None)
case f => f.some.pure[F]
}
} else s
}

private def optionallyDecodeClose[F[_]](s: Stream[F, WebSocketFrame], doDecodeClose: Boolean): Stream[F, WebSocketFrame] =
if (!doDecodeClose) {
s.takeWhile {
case _: WebSocketFrame.Close => false
case _ => true
}
} else s
}

0 comments on commit d54927e

Please sign in to comment.