diff --git a/server/vertx-server/cats/src/main/scala/sttp/tapir/server/vertx/cats/streams/fs2.scala b/server/vertx-server/cats/src/main/scala/sttp/tapir/server/vertx/cats/streams/fs2.scala index df62f8e3d8..58f2bd89af 100644 --- a/server/vertx-server/cats/src/main/scala/sttp/tapir/server/vertx/cats/streams/fs2.scala +++ b/server/vertx-server/cats/src/main/scala/sttp/tapir/server/vertx/cats/streams/fs2.scala @@ -1,10 +1,11 @@ package sttp.tapir.server.vertx.cats.streams import _root_.fs2.{Chunk, Stream} -import cats.effect.kernel.Async +import _root_.fs2.concurrent.Channel import cats.effect.kernel.Resource.ExitCase.{Canceled, Errored, Succeeded} -import cats.effect.{Deferred, Ref} +import cats.effect.{Deferred, Ref, Sync, Async, GenSpawn} import cats.syntax.all._ +import cats.effect.implicits._ import io.vertx.core.Handler import io.vertx.core.buffer.Buffer import io.vertx.core.streams.ReadStream @@ -29,7 +30,9 @@ object fs2 { dfd.get } - implicit def fs2ReadStreamCompatible[F[_]](opts: VertxCatsServerOptions[F])(implicit F: Async[F]): ReadStreamCompatible[Fs2Streams[F]] = { + implicit def fs2ReadStreamCompatible[F[_]: Async]( + opts: VertxCatsServerOptions[F] + ): ReadStreamCompatible[Fs2Streams[F]] = { new ReadStreamCompatible[Fs2Streams[F]] { override val streams: Fs2Streams[F] = Fs2Streams[F] @@ -41,13 +44,13 @@ object fs2 { for { promise <- Deferred[F, Unit] state <- Ref.of(StreamState.empty[F, O](promise)) - _ <- F.start( + _ <- GenSpawn[F].start( stream .evalMap({ chunk => val buffer = fn(chunk) state.get.flatMap { case StreamState(None, handler, _, _) => - F.delay(handler.handle(buffer)) + Sync[F].delay(handler.handle(buffer)) case StreamState(Some(promise), _, _, _) => for { _ <- promise.get @@ -60,15 +63,15 @@ object fs2 { .onFinalizeCase({ case Succeeded => state.get.flatMap { state => - F.delay(state.endHandler.handle(null)) + Sync[F].delay(state.endHandler.handle(null)) } case Canceled => state.get.flatMap { state => - F.delay(state.errorHandler.handle(new Exception("Cancelled!"))) + Sync[F].delay(state.errorHandler.handle(new Exception("Cancelled!"))) } case Errored(cause) => state.get.flatMap { state => - F.delay(state.errorHandler.handle(cause)) + Sync[F].delay(state.errorHandler.handle(cause)) } }) .compile @@ -136,7 +139,7 @@ object fs2 { } yield result } - _ <- F.start( + _ <- GenSpawn[F].start( Stream .unfoldEval[F, Unit, ActivationEvent](())({ _ => for { @@ -151,8 +154,8 @@ object fs2 { } yield result.map((_, ())) }) .evalMap({ - case Pause => F.delay(readStream.pause()) - case Resume => F.delay(readStream.resume()) + case Pause => Sync[F].delay(readStream.pause()) + case Resume => Sync[F].delay(readStream.resume()) }) .compile .drain @@ -173,36 +176,76 @@ object fs2 { } } - override def webSocketPipe[REQ, RESP]( - readStream: ReadStream[WebSocketFrame], + private def decodeFrame[REQ, RESP]( + frame: WebSocketFrame, + o: WebSocketBodyOutput[streams.Pipe[REQ, RESP], REQ, RESP, _, Fs2Streams[F]] + ): REQ = { + o.requests.decode(frame) match { + case DecodeResult.Value(v) => v + case failure: DecodeResult.Failure => throw new WebSocketFrameDecodeFailure(frame, failure) + } + } + + private def fastPath[REQ, RESP]( + in: Stream[F, WebSocketFrame], pipe: streams.Pipe[REQ, RESP], o: WebSocketBodyOutput[streams.Pipe[REQ, RESP], REQ, RESP, _, Fs2Streams[F]] ): ReadStream[WebSocketFrame] = { - val stream0 = fromReadStreamInternal(readStream) - val stream1 = optionallyContatenateFrames(stream0, o.concatenateFragmentedFrames) - val stream2 = optionallyIgnorePong(stream1, o.ignorePong) - val autoPings = o.autoPing match { - case Some((interval, frame)) => - Stream.awakeEvery(interval).as(frame) - case None => - Stream.empty - } + val contatenatedFrames = optionallyContatenateFrames(in, o.concatenateFragmentedFrames) + val ignorePongs = optionallyIgnorePong(contatenatedFrames, o.ignorePong) - val stream3 = stream2 - .map { frame => - o.requests.decode(frame) match { - case DecodeResult.Value(v) => - v - case failure: DecodeResult.Failure => - throw new WebSocketFrameDecodeFailure(frame, failure) - } - } + val stream = ignorePongs + .map(decodeFrame(_, o)) .through(pipe) .map(o.responses.encode) - .mergeHaltL(autoPings) .append(Stream(WebSocketFrame.close)) - mapToReadStream[WebSocketFrame, WebSocketFrame](stream3, identity) + mapToReadStream[WebSocketFrame, WebSocketFrame](stream, identity) + } + + private def standardPath[REQ, RESP]( + in: Stream[F, WebSocketFrame], + pipe: streams.Pipe[REQ, RESP], + o: WebSocketBodyOutput[streams.Pipe[REQ, RESP], REQ, RESP, _, Fs2Streams[F]] + ): ReadStream[WebSocketFrame] = { + val mergedStream = Channel.bounded[F, Chunk[WebSocketFrame]](64).map { c => + val contatenatedFrames = optionallyContatenateFrames(in, o.concatenateFragmentedFrames) + val ignoredPongs = optionallyIgnorePong(contatenatedFrames, o.ignorePong) + + val autoPings = o.autoPing match { + case Some((interval, frame)) => Stream.awakeEvery(interval).as(frame) + case None => Stream.empty + } + + val outputProducer = ignoredPongs + .map(decodeFrame(_, o)) + .through(pipe) + .chunks + .foreach(chunk => c.send(chunk.map(r => o.responses.encode(r))).void) + .compile + .drain + + val outcomes = (outputProducer.guarantee(c.close.void), autoPings.compile.drain).parTupled.void + + Stream.bracket(outcomes.start)(f => f.cancel >> f.joinWithUnit) >> + c.stream.append(Stream(Chunk.singleton(WebSocketFrame.close))).unchunks + } + + mapToReadStream[WebSocketFrame, WebSocketFrame](Stream.force(mergedStream), identity) + } + + override def webSocketPipe[REQ, RESP]( + readStream: ReadStream[WebSocketFrame], + pipe: streams.Pipe[REQ, RESP], + o: WebSocketBodyOutput[streams.Pipe[REQ, RESP], REQ, RESP, _, Fs2Streams[F]] + ): ReadStream[WebSocketFrame] = { + val stream = fromReadStreamInternal(readStream) + + if ((!o.autoPongOnPing) && o.autoPing.isEmpty) { + fastPath(stream, pipe, o) + } else { + standardPath(stream, pipe, o) + } } def optionallyContatenateFrames(s: Stream[F, WebSocketFrame], doConcatenate: Boolean): Stream[F, WebSocketFrame] =