Skip to content

Commit

Permalink
Updated Http4sWebSockets, added ignorePing to WebSocketBodyOutput.
Browse files Browse the repository at this point in the history
  • Loading branch information
Kamil Kloch committed Dec 12, 2023
1 parent 6c34d14 commit 87d2a5c
Show file tree
Hide file tree
Showing 6 changed files with 122 additions and 71 deletions.
9 changes: 9 additions & 0 deletions core/src/main/scala/sttp/tapir/EndpointIO.scala
Original file line number Diff line number Diff line change
Expand Up @@ -736,6 +736,7 @@ case class WebSocketBodyOutput[PIPE_REQ_RESP, REQ, RESP, T, S](
requestsInfo: Info[REQ],
responsesInfo: Info[RESP],
concatenateFragmentedFrames: Boolean,
ignorePing: Boolean,
ignorePong: Boolean,
autoPongOnPing: Boolean,
decodeCloseRequests: Boolean,
Expand Down Expand Up @@ -774,6 +775,14 @@ case class WebSocketBodyOutput[PIPE_REQ_RESP, REQ, RESP, T, S](
def concatenateFragmentedFrames(c: Boolean): WebSocketBodyOutput[PIPE_REQ_RESP, REQ, RESP, T, S] =
this.copy(concatenateFragmentedFrames = c)

/** Note: some interpreters ignore this setting.
*
* @param i
* If `true`, [[WebSocketFrame.Ping]] frames will be ignored and won't be passed to the codecs for decoding. Note that only some
* interpreters expose ping-pong frames.
*/
def ignorePing(i: Boolean): WebSocketBodyOutput[PIPE_REQ_RESP, REQ, RESP, T, S] = this.copy(ignorePing = i)

/** Note: some interpreters ignore this setting.
* @param i
* If `true`, [[WebSocketFrame.Pong]] frames will be ignored and won't be passed to the codecs for decoding. Note that only some
Expand Down
1 change: 1 addition & 0 deletions core/src/main/scala/sttp/tapir/Tapir.scala
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ trait Tapir extends TapirExtensions with TapirComputedInputs with TapirStaticCon
EndpointIO.Info.empty,
EndpointIO.Info.empty,
concatenateFragmentedFrames = true,
ignorePing = false,
ignorePong = true,
autoPongOnPing = true,
decodeCloseRequests = requests.schema.isOptional,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import sttp.tapir.server.interpreter.{BodyListener, FilterServerEndpoints, Serve
import sttp.tapir.server.model.ServerResponse

import scala.reflect.ClassTag
import cats.NonEmptyParallel

class Http4sInvalidWebSocketUse(val message: String) extends Exception

Expand All @@ -27,7 +28,8 @@ class Http4sInvalidWebSocketUse(val message: String) extends Exception
trait Context[T]

trait Http4sServerInterpreter[F[_]] {
implicit def fa: Async[F]
implicit def fa: Async[F]
implicit def nep: NonEmptyParallel[F]

def http4sServerOptions: Http4sServerOptions[F] = Http4sServerOptions.default[F]

Expand Down Expand Up @@ -146,15 +148,17 @@ trait Http4sServerInterpreter[F[_]] {

object Http4sServerInterpreter {

def apply[F[_]]()(implicit _fa: Async[F]): Http4sServerInterpreter[F] = {
def apply[F[_]]()(implicit _fa: Async[F], _nep: NonEmptyParallel[F]): Http4sServerInterpreter[F] = {
new Http4sServerInterpreter[F] {
override implicit def fa: Async[F] = _fa
override implicit def nep: NonEmptyParallel[F] = _nep
}
}

def apply[F[_]](serverOptions: Http4sServerOptions[F])(implicit _fa: Async[F]): Http4sServerInterpreter[F] = {
def apply[F[_]](serverOptions: Http4sServerOptions[F])(implicit _fa: Async[F], _nep: NonEmptyParallel[F]): Http4sServerInterpreter[F] = {
new Http4sServerInterpreter[F] {
override implicit def fa: Async[F] = _fa
override implicit def nep: NonEmptyParallel[F] = _nep
override def http4sServerOptions: Http4sServerOptions[F] = serverOptions
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package sttp.tapir.server.http4s

import cats.NonEmptyParallel
import cats.effect.{Async, Sync}
import cats.syntax.all._
import fs2.io.file.Files
Expand All @@ -17,7 +18,7 @@ import sttp.tapir.{CodecFormat, RawBodyType, RawPart, WebSocketBodyOutput}
import java.io.InputStream
import java.nio.charset.Charset

private[http4s] class Http4sToResponseBody[F[_]: Async](
private[http4s] class Http4sToResponseBody[F[_]: Async: NonEmptyParallel](
serverOptions: Http4sServerOptions[F]
) extends ToResponseBody[Http4sResponseBody[F], Fs2Streams[F]] {
override val streams: Fs2Streams[F] = Fs2Streams[F]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
package sttp.tapir.server.http4s

import cats.Monad
import cats.effect.Temporal
import cats.effect.std.Queue
import cats.effect.{Temporal, Sync}
import cats.NonEmptyParallel
import cats.syntax.all._
import cats.{Monad, NonEmptyParallel, Applicative}
import fs2._
import fs2.concurrent.Channel
import org.http4s.websocket.{WebSocketFrame => Http4sWebSocketFrame}
import scodec.bits.ByteVector
import sttp.capabilities.fs2.Fs2Streams
Expand All @@ -13,102 +14,136 @@ import sttp.tapir.{DecodeResult, WebSocketBodyOutput}
import sttp.ws.WebSocketFrame

private[http4s] object Http4sWebSockets {
def pipeToBody[F[_]: Temporal, REQ, RESP](
pipe: Pipe[F, REQ, RESP],
o: WebSocketBodyOutput[Pipe[F, REQ, RESP], REQ, RESP, _, Fs2Streams[F]]
): F[Pipe[F, Http4sWebSocketFrame, Http4sWebSocketFrame]] = {
Queue.bounded[F, WebSocketFrame](2).map { pongs => (in: Stream[F, Http4sWebSocketFrame]) =>
val sttpFrames = in.map(http4sFrameToFrame)
val concatenated = optionallyConcatenateFrames(sttpFrames, o.concatenateFragmentedFrames)
val ignorePongs = optionallyIgnorePong(concatenated, o.ignorePong)
val autoPongs = optionallyAutoPong(ignorePongs, pongs, o.autoPongOnPing)
val autoPings = o.autoPing match {
case Some((interval, frame)) => Stream.awakeEvery[F](interval).map(_ => frame)
case None => Stream.empty
}
val decodeClose = optionallyDecodeClose(autoPongs, o.decodeCloseRequests)
def pipeToBody[F[_]: NonEmptyParallel: Temporal, REQ, RESP](
pipe: Pipe[F, REQ, RESP],
o: WebSocketBodyOutput[Pipe[F, REQ, RESP], REQ, RESP, _, Fs2Streams[F]]
): F[Pipe[F, Http4sWebSocketFrame, Http4sWebSocketFrame]] = {
if ((!o.concatenateFragmentedFrames) && (!o.ignorePing) && (!o.ignorePong) && (!o.autoPongOnPing) && o.autoPing.isEmpty) {
// fast track: lift Http4sWebSocketFrames into REQ, run through pipe, convert RESP back to Http4sWebSocketFrame

(decodeClose
.map { f =>
o.requests.decode(f) match {
case x: DecodeResult.Value[REQ] => x.v
case failure: DecodeResult.Failure => throw new WebSocketFrameDecodeFailure(f, failure)
(in: Stream[F, Http4sWebSocketFrame]) =>
optionallyDecodeClose(in, o.decodeCloseRequests)
.map { http4sFrame =>
val f = http4sFrameToFrame(http4sFrame)
o.requests.decode(f) match {
case x: DecodeResult.Value[REQ] => x.v
case failure: DecodeResult.Failure => throw new WebSocketFrameDecodeFailure(f, failure)
}
}
}
.through(pipe)
.map(o.responses.encode)
.mergeHaltL(Stream.repeatEval(pongs.take))
.mergeHaltL(autoPings) ++ Stream(WebSocketFrame.close))
.map(frameToHttp4sFrame)
.through(pipe)
.mapChunks(_.map(r => frameToHttp4sFrame(o.responses.encode(r))))
.append(Stream(frameToHttp4sFrame(WebSocketFrame.close)))
}.pure[F] else {
// concurrently merge business logic response, autoPings, autoPongOnPing
// use fs2.Channel to perform the merge (more efficient than Stream#mergeHaltL / Stream#parJoin)

Channel.bounded[F, Chunk[Http4sWebSocketFrame]](64).map { c =>
(in: Stream[F, Http4sWebSocketFrame]) =>
val decodeClose = optionallyDecodeClose(in, o.decodeCloseRequests)
val sttpFrames = decodeClose.map(http4sFrameToFrame)
val concatenated = optionallyConcatenateFrames(sttpFrames, o.concatenateFragmentedFrames)
val autoPongs = optionallyAutoPong(concatenated, c, o.autoPongOnPing)
val ignorePingPongs = optionallyIgnorePingPong(autoPongs, o.ignorePing, o.ignorePong)
val autoPings = o.autoPing match {
case Some((interval, frame)) => (c.send(Chunk(frameToHttp4sFrame(frame))) >> Temporal[F].sleep(interval)).foreverM[Unit]
case None => Applicative[F].unit
}

val outputProducer = ignorePingPongs
.map { f =>
o.requests.decode(f) match {
case x: DecodeResult.Value[REQ] => x.v
case failure: DecodeResult.Failure => throw new WebSocketFrameDecodeFailure(f, failure)
}
}
.through(pipe)
.chunks
.foreach(chunk => c.send(chunk.map(r => frameToHttp4sFrame(o.responses.encode(r)))).void)
.compile
.drain

c.stream
.concurrently(Stream.exec((outputProducer >> c.close.void, autoPings).parTupled.void))
.append(Stream(Chunk(frameToHttp4sFrame(WebSocketFrame.close))))
.unchunks
}
}
}

private def http4sFrameToFrame(f: Http4sWebSocketFrame): WebSocketFrame =
f match {
case t: Http4sWebSocketFrame.Text => WebSocketFrame.Text(t.str, t.last, None)
case x: Http4sWebSocketFrame.Ping => WebSocketFrame.Ping(x.data.toArray)
case x: Http4sWebSocketFrame.Pong => WebSocketFrame.Pong(x.data.toArray)
case t: Http4sWebSocketFrame.Text => WebSocketFrame.Text(t.str, t.last, None)
case x: Http4sWebSocketFrame.Ping => WebSocketFrame.Ping(x.data.toArray)
case x: Http4sWebSocketFrame.Pong => WebSocketFrame.Pong(x.data.toArray)
case c: Http4sWebSocketFrame.Close => WebSocketFrame.Close(c.closeCode, "")
case _ => WebSocketFrame.Binary(f.data.toArray, f.last, None)
case _ => WebSocketFrame.Binary(f.data.toArray, f.last, None)
}

private def frameToHttp4sFrame(w: WebSocketFrame): Http4sWebSocketFrame =
w match {
case x: WebSocketFrame.Text => Http4sWebSocketFrame.Text(x.payload, x.finalFragment)
case x: WebSocketFrame.Text => Http4sWebSocketFrame.Text(x.payload, x.finalFragment)
case x: WebSocketFrame.Binary => Http4sWebSocketFrame.Binary(ByteVector(x.payload), x.finalFragment)
case x: WebSocketFrame.Ping => Http4sWebSocketFrame.Ping(ByteVector(x.payload))
case x: WebSocketFrame.Pong => Http4sWebSocketFrame.Pong(ByteVector(x.payload))
case x: WebSocketFrame.Close => Http4sWebSocketFrame.Close(x.statusCode, x.reasonText).fold(throw _, identity)
case x: WebSocketFrame.Ping => Http4sWebSocketFrame.Ping(ByteVector(x.payload))
case x: WebSocketFrame.Pong => Http4sWebSocketFrame.Pong(ByteVector(x.payload))
case x: WebSocketFrame.Close => Http4sWebSocketFrame.Close(x.statusCode, x.reasonText).fold(throw _, identity)
}

private def optionallyConcatenateFrames[F[_]](s: Stream[F, WebSocketFrame], doConcatenate: Boolean): Stream[F, WebSocketFrame] = {
private def optionallyConcatenateFrames[F[_]](s: Stream[F, WebSocketFrame], doConcatenate: Boolean): Stream[F, WebSocketFrame] =
if (doConcatenate) {
type Accumulator = Option[Either[Array[Byte], String]]

s.mapAccumulate(None: Accumulator) {
case (None, f: WebSocketFrame.Ping) => (None, Some(f))
case (None, f: WebSocketFrame.Pong) => (None, Some(f))
case (None, f: WebSocketFrame.Close) => (None, Some(f))
case (None, f: WebSocketFrame.Data[_]) if f.finalFragment => (None, Some(f))
case (Some(Left(acc)), f: WebSocketFrame.Binary) if f.finalFragment => (None, Some(f.copy(payload = acc ++ f.payload)))
case (None, f: WebSocketFrame.Ping) => (None, Some(f))
case (None, f: WebSocketFrame.Pong) => (None, Some(f))
case (None, f: WebSocketFrame.Close) => (None, Some(f))
case (None, f: WebSocketFrame.Data[_]) if f.finalFragment => (None, Some(f))
case (Some(Left(acc)), f: WebSocketFrame.Binary) if f.finalFragment => (None, Some(f.copy(payload = acc ++ f.payload)))
case (Some(Left(acc)), f: WebSocketFrame.Binary) if !f.finalFragment => (Some(Left(acc ++ f.payload)), None)
case (Some(Right(acc)), f: WebSocketFrame.Text) if f.finalFragment => (None, Some(f.copy(payload = acc + f.payload)))
case (Some(Right(acc)), f: WebSocketFrame.Text) if !f.finalFragment => (Some(Right(acc + f.payload)), None)
case (Some(Right(acc)), f: WebSocketFrame.Text) if f.finalFragment => (None, Some(f.copy(payload = acc + f.payload)))
case (Some(Right(acc)), f: WebSocketFrame.Text) if !f.finalFragment => (Some(Right(acc + f.payload)), None)
case (acc, f) => throw new IllegalStateException(s"Cannot accumulate web socket frames. Accumulator: $acc, frame: $f.")
}.collect { case (_, Some(f)) => f }
} else {
s
}
}

private def optionallyIgnorePong[F[_]](s: Stream[F, WebSocketFrame], doIgnore: Boolean): Stream[F, WebSocketFrame] = {
if (doIgnore) {
s.filter {
case _: WebSocketFrame.Pong => false
case _ => true
}
} else s
}

private def optionallyAutoPong[F[_]: Monad](
s: Stream[F, WebSocketFrame],
pongs: Queue[F, WebSocketFrame],
doAuto: Boolean
): Stream[F, WebSocketFrame] = {
private def optionallyIgnorePingPong[F[_]](s: Stream[F, WebSocketFrame], ignorePing: Boolean, ignorePong: Boolean): Stream[F, WebSocketFrame] =
(ignorePing, ignorePong) match {
case (false, false) => s
case (true, false) =>
s.filter {
case _: WebSocketFrame.Ping => false
case _ => true
}
case (false, true) =>
s.filter {
case _: WebSocketFrame.Pong => false
case _ => true
}
case (true, true) =>
s.filter {
case _: WebSocketFrame.Ping => false
case _: WebSocketFrame.Pong => false
case _ => true
}
}

private def optionallyAutoPong[F[_] : Monad](
s: Stream[F, WebSocketFrame],
c: Channel[F, Chunk[Http4sWebSocketFrame]],
doAuto: Boolean
): Stream[F, WebSocketFrame] =
if (doAuto) {
val trueF = true.pure[F]
s.evalFilter {
case ping: WebSocketFrame.Ping => pongs.offer(WebSocketFrame.Pong(ping.payload)).map(_ => false)
case _ => trueF
case ping: WebSocketFrame.Ping => c.send(Chunk(frameToHttp4sFrame(WebSocketFrame.Pong(ping.payload)))).map(_ => false)
case _ => trueF
}
} else s
}

private def optionallyDecodeClose[F[_]](s: Stream[F, WebSocketFrame], doDecodeClose: Boolean): Stream[F, WebSocketFrame] =
private def optionallyDecodeClose[F[_]](s: Stream[F, Http4sWebSocketFrame], doDecodeClose: Boolean): Stream[F, Http4sWebSocketFrame] =
if (!doDecodeClose) {
s.takeWhile {
case _: WebSocketFrame.Close => false
case _ => true
case _: Http4sWebSocketFrame.Close => false
case _ => true
}
} else s
}
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ object ConvertStreams {
w2.requestsInfo,
w2.responsesInfo,
w2.concatenateFragmentedFrames,
w2.ignorePing,
w2.ignorePong,
w2.autoPongOnPing,
w2.decodeCloseRequests,
Expand Down

0 comments on commit 87d2a5c

Please sign in to comment.