diff --git a/server/http4s-server/src/main/scala/sttp/tapir/server/http4s/Http4sWebSockets.scala b/server/http4s-server/src/main/scala/sttp/tapir/server/http4s/Http4sWebSockets.scala index 56dab4c213..f711117d1b 100644 --- a/server/http4s-server/src/main/scala/sttp/tapir/server/http4s/Http4sWebSockets.scala +++ b/server/http4s-server/src/main/scala/sttp/tapir/server/http4s/Http4sWebSockets.scala @@ -11,22 +11,25 @@ import sttp.capabilities.fs2.Fs2Streams import sttp.tapir.model.WebSocketFrameDecodeFailure import sttp.tapir.{DecodeResult, WebSocketBodyOutput} import sttp.ws.WebSocketFrame -import cats.effect.implicits.parallelForGenSpawn +import cats.effect.implicits._ private[http4s] object Http4sWebSockets { def pipeToBody[F[_]: Temporal, REQ, RESP]( - pipe: Pipe[F, REQ, RESP], - o: WebSocketBodyOutput[Pipe[F, REQ, RESP], REQ, RESP, _, Fs2Streams[F]] - ): F[Pipe[F, Http4sWebSocketFrame, Http4sWebSocketFrame]] = { - if ((!o.concatenateFragmentedFrames) && (!o.ignorePong) && (!o.autoPongOnPing) && o.autoPing.isEmpty) { + pipe: Pipe[F, REQ, RESP], + o: WebSocketBodyOutput[Pipe[F, REQ, RESP], REQ, RESP, _, Fs2Streams[F]] + ): F[Pipe[F, Http4sWebSocketFrame, Http4sWebSocketFrame]] = { + if ((!o.autoPongOnPing) && o.autoPing.isEmpty) { // fast track: lift Http4sWebSocketFrames into REQ, run through pipe, convert RESP back to Http4sWebSocketFrame (in: Stream[F, Http4sWebSocketFrame]) => - optionallyDecodeClose(in, o.decodeCloseRequests) - .map { http4sFrame => - val f = http4sFrameToFrame(http4sFrame) + val decodeClose = optionallyDecodeClose(in, o.decodeCloseRequests) + val sttpFrames = decodeClose.map(http4sFrameToFrame) + val concatenated = optionallyConcatenateFrames(sttpFrames, o.concatenateFragmentedFrames) + val ignorePongs = optionallyIgnorePong(concatenated, o.ignorePong) + ignorePongs + .map { f => o.requests.decode(f) match { - case x: DecodeResult.Value[REQ] => x.v + case x: DecodeResult.Value[REQ] => x.v case failure: DecodeResult.Failure => throw new WebSocketFrameDecodeFailure(f, failure) } } @@ -37,35 +40,35 @@ private[http4s] object Http4sWebSockets { // concurrently merge business logic response, autoPings, autoPongOnPing // use fs2.Channel to perform the merge (more efficient than Stream#mergeHaltL / Stream#parJoin) - Channel.bounded[F, Chunk[Http4sWebSocketFrame]](64).map { c => - (in: Stream[F, Http4sWebSocketFrame]) => - val decodeClose = optionallyDecodeClose(in, o.decodeCloseRequests) - val sttpFrames = decodeClose.map(http4sFrameToFrame) - val concatenated = optionallyConcatenateFrames(sttpFrames, o.concatenateFragmentedFrames) - val ignorePongs = optionallyIgnorePong(concatenated, o.ignorePong) - val autoPongs = optionallyAutoPong(ignorePongs, c, o.autoPongOnPing) - val autoPings = o.autoPing match { - case Some((interval, frame)) => (c.send(Chunk.singleton(frameToHttp4sFrame(frame))) >> Temporal[F].sleep(interval)).foreverM[Unit] - case None => Applicative[F].unit - } + Channel.bounded[F, Chunk[Http4sWebSocketFrame]](64).map { c => (in: Stream[F, Http4sWebSocketFrame]) => + val decodeClose = optionallyDecodeClose(in, o.decodeCloseRequests) + val sttpFrames = decodeClose.map(http4sFrameToFrame) + val concatenated = optionallyConcatenateFrames(sttpFrames, o.concatenateFragmentedFrames) + val ignorePongs = optionallyIgnorePong(concatenated, o.ignorePong) + val autoPongs = optionallyAutoPong(ignorePongs, c, o.autoPongOnPing) + val autoPings = o.autoPing match { + case Some((interval, frame)) => (c.send(Chunk.singleton(frameToHttp4sFrame(frame))) >> Temporal[F].sleep(interval)).foreverM[Unit] + case None => Applicative[F].unit + } - val outputProducer = autoPongs - .map { f => - o.requests.decode(f) match { - case x: DecodeResult.Value[REQ] => x.v - case failure: DecodeResult.Failure => throw new WebSocketFrameDecodeFailure(f, failure) - } + val outputProducer = autoPongs + .map { f => + o.requests.decode(f) match { + case x: DecodeResult.Value[REQ] => x.v + case failure: DecodeResult.Failure => throw new WebSocketFrameDecodeFailure(f, failure) } - .through(pipe) - .chunks - .foreach(chunk => c.send(chunk.map(r => frameToHttp4sFrame(o.responses.encode(r)))).void) - .compile - .drain + } + .through(pipe) + .chunks + .foreach(chunk => c.send(chunk.map(r => frameToHttp4sFrame(o.responses.encode(r)))).void) + .compile + .drain + + val outcomes = (outputProducer.guarantee(c.close.void), autoPings).parTupled.void - c.stream - .concurrently(Stream.exec((outputProducer >> c.close.void, autoPings).parTupled.void)) - .append(Stream(Chunk.singleton(frameToHttp4sFrame(WebSocketFrame.close)))) - .unchunks + Stream + .bracket(Temporal[F].start(outcomes))(f => f.cancel >> f.joinWithUnit) >> + c.stream.append(Stream(Chunk.singleton(frameToHttp4sFrame(WebSocketFrame.close)))).unchunks } } }