diff --git a/build.sbt b/build.sbt index 3cccfc5086..feed8f00ed 100644 --- a/build.sbt +++ b/build.sbt @@ -2196,6 +2196,7 @@ lazy val documentation: ProjectMatrix = (projectMatrix in file("generated-doc")) sprayJson, http4sClient, http4sServerZio, + nettyServerCats, sttpClient, playClient, sttpStubServer, diff --git a/doc/server/netty.md b/doc/server/netty.md index 3fd9cec741..cbeee80c9c 100644 --- a/doc/server/netty.md +++ b/doc/server/netty.md @@ -83,6 +83,71 @@ NettyFutureServer(NettyFutureServerOptions.customiseInterceptors.serverLog(None) NettyFutureServer(NettyConfig.default.socketBacklog(256)) ``` +## Web sockets + +The netty-cats interpreter supports web sockets, with pipes of type `fs2.Pipe[F, REQ, RESP]`. See [web sockets](../endpoint/websockets.md) +for more details. + +To create a web socket endpoint, use Tapir's `out(webSocketBody)` output type: + +```scala mdoc:compile-only +import cats.effect.kernel.Resource +import cats.effect.{IO, ResourceApp} +import cats.syntax.all._ +import fs2.Pipe +import sttp.capabilities.fs2.Fs2Streams +import sttp.tapir._ +import sttp.tapir.server.netty.cats.NettyCatsServer +import sttp.ws.WebSocketFrame + +import scala.concurrent.duration._ + +object WebSocketsNettyCatsServer extends ResourceApp.Forever { + + // Web socket endpoint + val wsEndpoint = + endpoint.get + .in("ws") + .out( + webSocketBody[String, CodecFormat.TextPlain, String, CodecFormat.TextPlain](Fs2Streams[IO]) + .concatenateFragmentedFrames(false) // All these options are supported by tapir-netty + .ignorePong(true) + .autoPongOnPing(true) + .decodeCloseRequests(false) + .decodeCloseResponses(false) + .autoPing(Some((10.seconds, WebSocketFrame.Ping("ping-content".getBytes)))) + ) + + // Your processor transforming a stream of requests into a stream of responses + val pipe: Pipe[IO, String, String] = requestStream => requestStream.evalMap(str => IO.pure(str.toUpperCase)) + // Alternatively, requests can be ignored and the backend can be turned into a stream emitting frames to the client: + // val pipe: Pipe[IO, String, String] = requestStream => someDataEmittingStream.concurrently(requestStream.as(())) + + val wsServerEndpoint = wsEndpoint.serverLogicSuccess(_ => IO.pure(pipe)) + + // A regular /GET endpoint + val helloWorldEndpoint: PublicEndpoint[String, Unit, String, Any] = + endpoint.get.in("hello").in(query[String]("name")).out(stringBody) + + val helloWorldServerEndpoint = helloWorldEndpoint + .serverLogicSuccess(name => IO.pure(s"Hello, $name!")) + + override def run(args: List[String]) = NettyCatsServer + .io() + .flatMap { server => + Resource + .make( + server + .port(8080) + .host("localhost") + .addEndpoints(List(wsServerEndpoint, helloWorldServerEndpoint)) + .start() + )(_.stop()) + .as(()) + } +} +``` + ## Graceful shutdown A Netty server can be gracefully closed using the function `NettyFutureServerBinding.stop()` (and analogous functions available in Cats and ZIO bindings). This function ensures that the server will wait at most 10 seconds for in-flight requests to complete, while rejecting all new requests with 503 during this period. Afterwards, it closes all server resources. diff --git a/perf-tests/src/main/scala/sttp/tapir/perf/netty/cats/NettyCats.scala b/perf-tests/src/main/scala/sttp/tapir/perf/netty/cats/NettyCats.scala index 85d3ac7c2b..10639eb7a9 100644 --- a/perf-tests/src/main/scala/sttp/tapir/perf/netty/cats/NettyCats.scala +++ b/perf-tests/src/main/scala/sttp/tapir/perf/netty/cats/NettyCats.scala @@ -3,16 +3,36 @@ package sttp.tapir.perf.netty.cats import cats.effect.IO import cats.effect.kernel.Resource import cats.effect.std.Dispatcher +import fs2.Stream +import sttp.tapir.{CodecFormat, webSocketBody} import sttp.tapir.perf.Common._ import sttp.tapir.perf.apis._ import sttp.tapir.server.ServerEndpoint import sttp.tapir.server.netty.cats.NettyCatsServer import sttp.tapir.server.netty.cats.NettyCatsServerOptions +import sttp.ws.WebSocketFrame +import sttp.capabilities.fs2.Fs2Streams -object Tapir extends Endpoints +import scala.concurrent.duration._ -object NettyCats { +object Tapir extends Endpoints { + val wsResponseStream = Stream.fixedRate[IO](WebSocketSingleResponseLag, dampen = false) + val wsEndpoint = wsBaseEndpoint + .out( + webSocketBody[Long, CodecFormat.TextPlain, Long, CodecFormat.TextPlain](Fs2Streams[IO]) + .concatenateFragmentedFrames(false) + .autoPongOnPing(false) + .ignorePong(true) + .autoPing(None) + ) +} +object NettyCats { + val wsServerEndpoint = Tapir.wsEndpoint.serverLogicSuccess(_ => + IO.pure { (in: Stream[IO, Long]) => + Tapir.wsResponseStream.evalMap(_ => IO.realTime.map(_.toMillis)).concurrently(in.as(())) + } + ) def runServer(endpoints: List[ServerEndpoint[Any, IO]], withServerLog: Boolean = false): IO[ServerRunner.KillSwitch] = { val declaredPort = Port val declaredHost = "0.0.0.0" @@ -25,7 +45,7 @@ object NettyCats { server .port(declaredPort) .host(declaredHost) - .addEndpoints(endpoints) + .addEndpoints(wsServerEndpoint :: endpoints) .start() )(binding => binding.stop()) } yield ()).allocated.map(_._2) diff --git a/server/akka-http-server/src/test/scala/sttp/tapir/server/akkahttp/AkkaHttpServerTest.scala b/server/akka-http-server/src/test/scala/sttp/tapir/server/akkahttp/AkkaHttpServerTest.scala index 3cb5e63b73..db589bce84 100644 --- a/server/akka-http-server/src/test/scala/sttp/tapir/server/akkahttp/AkkaHttpServerTest.scala +++ b/server/akka-http-server/src/test/scala/sttp/tapir/server/akkahttp/AkkaHttpServerTest.scala @@ -157,7 +157,13 @@ class AkkaHttpServerTest extends TestSuite with EitherValues { new AllServerTests(createServerTest, interpreter, backend).tests() ++ new ServerStreamingTests(createServerTest).tests(AkkaStreams)(drainAkka) ++ - new ServerWebSocketTests(createServerTest, AkkaStreams) { + new ServerWebSocketTests( + createServerTest, + AkkaStreams, + autoPing = false, + failingPipe = true, + handlePong = false + ) { override def functionToPipe[A, B](f: A => B): streams.Pipe[A, B] = Flow.fromFunction(f) override def emptyPipe[A, B]: Flow[A, B, Any] = Flow.fromSinkAndSource(Sink.ignore, Source.empty) }.tests() ++ diff --git a/server/http4s-server/src/test/scala/sttp/tapir/server/http4s/Http4sServerTest.scala b/server/http4s-server/src/test/scala/sttp/tapir/server/http4s/Http4sServerTest.scala index 844e628024..7f6028b33a 100644 --- a/server/http4s-server/src/test/scala/sttp/tapir/server/http4s/Http4sServerTest.scala +++ b/server/http4s-server/src/test/scala/sttp/tapir/server/http4s/Http4sServerTest.scala @@ -5,6 +5,7 @@ import cats.effect._ import cats.effect.unsafe.implicits.global import cats.syntax.all._ import fs2.Pipe +import fs2.Stream import org.http4s.blaze.server.BlazeServerBuilder import org.http4s.server.Router import org.http4s.server.ContextMiddleware @@ -138,7 +139,13 @@ class Http4sServerTest[R >: Fs2Streams[IO] with WebSockets] extends TestSuite wi new AllServerTests(createServerTest, interpreter, backend).tests() ++ new ServerStreamingTests(createServerTest).tests(Fs2Streams[IO])(drainFs2) ++ - new ServerWebSocketTests(createServerTest, Fs2Streams[IO]) { + new ServerWebSocketTests( + createServerTest, + Fs2Streams[IO], + autoPing = true, + failingPipe = true, + handlePong = false + ) { override def functionToPipe[A, B](f: A => B): streams.Pipe[A, B] = in => in.map(f) override def emptyPipe[A, B]: Pipe[IO, A, B] = _ => fs2.Stream.empty }.tests() ++ diff --git a/server/http4s-server/zio/src/test/scala/sttp/tapir/server/http4s/ztapir/ZHttp4sServerTest.scala b/server/http4s-server/zio/src/test/scala/sttp/tapir/server/http4s/ztapir/ZHttp4sServerTest.scala index 99ff7c2d4d..5aafe6caa7 100644 --- a/server/http4s-server/zio/src/test/scala/sttp/tapir/server/http4s/ztapir/ZHttp4sServerTest.scala +++ b/server/http4s-server/zio/src/test/scala/sttp/tapir/server/http4s/ztapir/ZHttp4sServerTest.scala @@ -55,7 +55,13 @@ class ZHttp4sServerTest extends TestSuite with OptionValues { new AllServerTests(createServerTest, interpreter, backend).tests() ++ new ServerStreamingTests(createServerTest).tests(ZioStreams)(drainZStream) ++ - new ServerWebSocketTests(createServerTest, ZioStreams) { + new ServerWebSocketTests( + createServerTest, + ZioStreams, + autoPing = true, + failingPipe = false, + handlePong = false + ) { override def functionToPipe[A, B](f: A => B): streams.Pipe[A, B] = in => in.map(f) override def emptyPipe[A, B]: streams.Pipe[A, B] = _ => ZStream.empty }.tests() ++ diff --git a/server/netty-server/cats/src/main/scala-2.12/sttp/tapir/server/netty/cats/internal/ExecutionContexts.scala b/server/netty-server/cats/src/main/scala-2.12/sttp/tapir/server/netty/cats/internal/ExecutionContexts.scala new file mode 100644 index 0000000000..f88a74ed12 --- /dev/null +++ b/server/netty-server/cats/src/main/scala-2.12/sttp/tapir/server/netty/cats/internal/ExecutionContexts.scala @@ -0,0 +1,12 @@ +package sttp.tapir.server.netty.cats.internal + +import scala.concurrent.ExecutionContext + +object ExecutionContexts { + val sameThread: ExecutionContext = new ExecutionContext { + override def execute(runnable: Runnable): Unit = runnable.run() + + override def reportFailure(cause: Throwable): Unit = + ExecutionContext.defaultReporter(cause) + } +} diff --git a/server/netty-server/cats/src/main/scala-3-2.13+/sttp/tapir/server/netty/cats/internal/ExecutionContexts.scala b/server/netty-server/cats/src/main/scala-3-2.13+/sttp/tapir/server/netty/cats/internal/ExecutionContexts.scala new file mode 100644 index 0000000000..dacc9e5473 --- /dev/null +++ b/server/netty-server/cats/src/main/scala-3-2.13+/sttp/tapir/server/netty/cats/internal/ExecutionContexts.scala @@ -0,0 +1,7 @@ +package sttp.tapir.server.netty.cats.internal + +import scala.concurrent.ExecutionContext + +object ExecutionContexts { + val sameThread: ExecutionContext = ExecutionContext.parasitic +} diff --git a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServer.scala b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServer.scala index 1b7a449326..8ebe94f74c 100644 --- a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServer.scala +++ b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServer.scala @@ -8,6 +8,7 @@ import io.netty.channel._ import io.netty.channel.group.{ChannelGroup, DefaultChannelGroup} import io.netty.channel.unix.DomainSocketAddress import io.netty.util.concurrent.DefaultEventExecutor +import sttp.capabilities.WebSockets import sttp.capabilities.fs2.Fs2Streams import sttp.monad.MonadError import sttp.tapir.integ.cats.effect.CatsMonadError @@ -25,13 +26,16 @@ import scala.concurrent.Future import scala.concurrent.duration._ case class NettyCatsServer[F[_]: Async](routes: Vector[Route[F]], options: NettyCatsServerOptions[F], config: NettyConfig) { - def addEndpoint(se: ServerEndpoint[Fs2Streams[F], F]): NettyCatsServer[F] = addEndpoints(List(se)) - def addEndpoint(se: ServerEndpoint[Fs2Streams[F], F], overrideOptions: NettyCatsServerOptions[F]): NettyCatsServer[F] = + def addEndpoint(se: ServerEndpoint[Fs2Streams[F] with WebSockets, F]): NettyCatsServer[F] = addEndpoints(List(se)) + def addEndpoint(se: ServerEndpoint[Fs2Streams[F], F] with WebSockets, overrideOptions: NettyCatsServerOptions[F]): NettyCatsServer[F] = addEndpoints(List(se), overrideOptions) - def addEndpoints(ses: List[ServerEndpoint[Fs2Streams[F], F]]): NettyCatsServer[F] = addRoute( + def addEndpoints(ses: List[ServerEndpoint[Fs2Streams[F] with WebSockets, F]]): NettyCatsServer[F] = addRoute( NettyCatsServerInterpreter(options).toRoute(ses) ) - def addEndpoints(ses: List[ServerEndpoint[Fs2Streams[F], F]], overrideOptions: NettyCatsServerOptions[F]): NettyCatsServer[F] = addRoute( + def addEndpoints( + ses: List[ServerEndpoint[Fs2Streams[F] with WebSockets, F]], + overrideOptions: NettyCatsServerOptions[F] + ): NettyCatsServer[F] = addRoute( NettyCatsServerInterpreter(overrideOptions).toRoute(ses) ) @@ -74,7 +78,7 @@ case class NettyCatsServer[F[_]: Async](routes: Vector[Route[F]], options: Netty val channelFuture = NettyBootstrap( config, - new NettyServerHandler(route, unsafeRunAsync, channelGroup, isShuttingDown, config.serverHeader), + new NettyServerHandler(route, unsafeRunAsync, channelGroup, isShuttingDown, config.serverHeader, config.isSsl), eventLoopGroup, socketOverride ) diff --git a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServerInterpreter.scala b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServerInterpreter.scala index 9e1c4722c0..8f65e39a73 100644 --- a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServerInterpreter.scala +++ b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServerInterpreter.scala @@ -14,12 +14,13 @@ import sttp.tapir.server.interpreter.{BodyListener, FilterServerEndpoints, Serve import sttp.tapir.server.netty.internal.{NettyBodyListener, RunAsync, _} import sttp.tapir.server.netty.cats.internal.NettyCatsRequestBody import sttp.tapir.server.netty.{NettyResponse, NettyServerRequest, Route} +import sttp.capabilities.WebSockets trait NettyCatsServerInterpreter[F[_]] { implicit def async: Async[F] def nettyServerOptions: NettyCatsServerOptions[F] - def toRoute(ses: List[ServerEndpoint[Fs2Streams[F], F]]): Route[F] = { + def toRoute(ses: List[ServerEndpoint[Fs2Streams[F] with WebSockets, F]]): Route[F] = { implicit val monad: MonadError[F] = new CatsMonadError[F] val runAsync = new RunAsync[F] { @@ -31,7 +32,7 @@ trait NettyCatsServerInterpreter[F[_]] { val createFile = nettyServerOptions.createFile val deleteFile = nettyServerOptions.deleteFile - val serverInterpreter = new ServerInterpreter[Fs2Streams[F], F, NettyResponse, Fs2Streams[F]]( + val serverInterpreter = new ServerInterpreter[Fs2Streams[F] with WebSockets, F, NettyResponse, Fs2Streams[F]]( FilterServerEndpoints(ses), new NettyCatsRequestBody(createFile, Fs2StreamCompatible[F](nettyServerOptions.dispatcher)), new NettyToStreamsResponseBody(Fs2StreamCompatible[F](nettyServerOptions.dispatcher)), diff --git a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/Fs2StreamCompatible.scala b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/Fs2StreamCompatible.scala index 3849fb4a19..8f257f3559 100644 --- a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/Fs2StreamCompatible.scala +++ b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/Fs2StreamCompatible.scala @@ -1,22 +1,20 @@ package sttp.tapir.server.netty.cats.internal +import cats.effect.kernel.{Async, Sync} +import cats.effect.std.Dispatcher +import fs2.interop.reactivestreams.{StreamSubscriber, StreamUnicastPublisher} +import fs2.io.file.{Files, Flags, Path} +import fs2.{Chunk, Pipe} import io.netty.buffer.Unpooled +import io.netty.channel.{ChannelFuture, ChannelHandlerContext} +import io.netty.handler.codec.http.websocketx._ import io.netty.handler.codec.http.{DefaultHttpContent, HttpContent} -import org.reactivestreams.Publisher -import sttp.tapir.FileRange +import org.reactivestreams.{Processor, Publisher} +import sttp.capabilities.fs2.Fs2Streams import sttp.tapir.server.netty.internal._ +import sttp.tapir.{FileRange, WebSocketBodyOutput} import java.io.InputStream -import cats.effect.std.Dispatcher -import sttp.capabilities.fs2.Fs2Streams -import fs2.io.file.Path -import fs2.io.file.Files -import cats.effect.kernel.Async -import fs2.io.file.Flags -import fs2.interop.reactivestreams.StreamUnicastPublisher -import cats.effect.kernel.Sync -import fs2.Chunk -import fs2.interop.reactivestreams.StreamSubscriber object Fs2StreamCompatible { @@ -68,6 +66,26 @@ object Fs2StreamCompatible { override def emptyStream: streams.BinaryStream = fs2.Stream.empty + override def asWsProcessor[REQ, RESP]( + pipe: Pipe[F, REQ, RESP], + o: WebSocketBodyOutput[Pipe[F, REQ, RESP], REQ, RESP, ?, Fs2Streams[F]], + ctx: ChannelHandlerContext + ): Processor[WebSocketFrame, WebSocketFrame] = { + val wsCompletedPromise = ctx.newPromise() + wsCompletedPromise.addListener((f: ChannelFuture) => { + // A special callback that has to be used when a SteramSubscription cancels or fails. + // This can happen in case of errors in the pipeline which are not signalled correctly, + // like throwing exceptions directly. + // Without explicit Close frame a client may hang on waiting and not knowing about closed channel. + if (f.isCancelled) { + val _ = ctx.writeAndFlush(new CloseWebSocketFrame(WebSocketCloseStatus.NORMAL_CLOSURE, "Canceled")) + } else if (!f.isSuccess) { + val _ = ctx.writeAndFlush(new CloseWebSocketFrame(WebSocketCloseStatus.INTERNAL_SERVER_ERROR, "Error")) + } + }) + new WebSocketPipeProcessor[F, REQ, RESP](pipe, dispatcher, o, wsCompletedPromise) + } + private def inputStreamToFs2(inputStream: () => InputStream, chunkSize: Int) = fs2.io.readInputStream( Sync[F].blocking(inputStream()), diff --git a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/NettyCatsRequestBody.scala b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/NettyCatsRequestBody.scala index 4dfe6c22fb..98b0742b74 100644 --- a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/NettyCatsRequestBody.scala +++ b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/NettyCatsRequestBody.scala @@ -12,6 +12,7 @@ import sttp.tapir.TapirFile import sttp.tapir.integ.cats.effect.CatsMonadError import sttp.tapir.model.ServerRequest import sttp.tapir.server.netty.internal.{NettyStreamingRequestBody, StreamCompatible} +import sttp.capabilities.WebSockets private[cats] class NettyCatsRequestBody[F[_]: Async]( val createFile: ServerRequest => F[TapirFile], @@ -24,7 +25,8 @@ private[cats] class NettyCatsRequestBody[F[_]: Async]( streamCompatible.fromPublisher(publisher, maxBytes).compile.to(Chunk).map(_.toArray[Byte]) override def writeToFile(serverRequest: ServerRequest, file: TapirFile, maxBytes: Option[Long]): F[Unit] = - (toStream(serverRequest, maxBytes).asInstanceOf[streamCompatible.streams.BinaryStream]) + (toStream(serverRequest, maxBytes) + .asInstanceOf[streamCompatible.streams.BinaryStream]) .through( Files[F](Files.forAsync[F]).writeAll(Path.fromNioPath(file.toPath)) ) diff --git a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/WebSocketPipeProcessor.scala b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/WebSocketPipeProcessor.scala new file mode 100644 index 0000000000..26e50651de --- /dev/null +++ b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/WebSocketPipeProcessor.scala @@ -0,0 +1,131 @@ +package sttp.tapir.server.netty.cats.internal + +import cats.Applicative +import cats.effect.kernel.Resource.ExitCase +import cats.effect.kernel.{Async, Sync} +import cats.effect.std.Dispatcher +import cats.syntax.all._ +import fs2.interop.reactivestreams.{StreamSubscriber, StreamUnicastPublisher} +import fs2.{Pipe, Stream} +import io.netty.channel.ChannelPromise +import io.netty.handler.codec.http.websocketx.{WebSocketFrame => NettyWebSocketFrame} +import org.reactivestreams.{Processor, Publisher, Subscriber, Subscription} +import org.slf4j.LoggerFactory +import sttp.capabilities.fs2.Fs2Streams +import sttp.tapir.model.WebSocketFrameDecodeFailure +import sttp.tapir.server.netty.internal.ws.WebSocketFrameConverters._ +import sttp.tapir.{DecodeResult, WebSocketBodyOutput} +import sttp.ws.WebSocketFrame + +import scala.concurrent.Promise +import scala.util.Success + +/** A Reactive Streams Processor[NettyWebSocketFrame, NettyWebSocketFrame] built from a fs2.Pipe[F, REQ, RESP] passed from an WS endpoint. + */ +class WebSocketPipeProcessor[F[_]: Async, REQ, RESP]( + pipe: Pipe[F, REQ, RESP], + dispatcher: Dispatcher[F], + o: WebSocketBodyOutput[Pipe[F, REQ, RESP], REQ, RESP, ?, Fs2Streams[F]], + wsCompletedPromise: ChannelPromise +) extends Processor[NettyWebSocketFrame, NettyWebSocketFrame] { + // Not really that unsafe. Subscriber creation doesn't do any IO, only initializes an AtomicReference in an initial state. + private val subscriber: StreamSubscriber[F, NettyWebSocketFrame] = dispatcher.unsafeRunSync( + // If bufferSize > 1, the stream may stale and not emit responses until enough requests are buffered + StreamSubscriber[F, NettyWebSocketFrame](bufferSize = 1) + ) + private val publisher: Promise[Publisher[NettyWebSocketFrame]] = Promise[Publisher[NettyWebSocketFrame]]() + private val logger = LoggerFactory.getLogger(getClass.getName) + + override def onSubscribe(s: Subscription): Unit = { + val subscription = new NonCancelingSubscription(s) + val in: Stream[F, NettyWebSocketFrame] = subscriber.sub.stream(Applicative[F].unit) + val sttpFrames = in.map { f => + val sttpFrame = nettyFrameToFrame(f) + f.release() + sttpFrame + } + val stream: Stream[F, NettyWebSocketFrame] = + optionallyConcatenateFrames(sttpFrames, o.concatenateFragmentedFrames) + .map(f => + o.requests.decode(f) match { + case x: DecodeResult.Value[REQ] => x.v + case failure: DecodeResult.Failure => throw new WebSocketFrameDecodeFailure(f, failure) + } + ) + .through(pipe) + .map(r => frameToNettyFrame(o.responses.encode(r))) + .onFinalizeCaseWeak { + case ExitCase.Succeeded => + Sync[F].delay { val _ = wsCompletedPromise.setSuccess() } + case ExitCase.Errored(t) => + Sync[F].delay(wsCompletedPromise.setFailure(t)) >> Sync[F].delay(logger.error("Error occured in WebSocket channel", t)) + case ExitCase.Canceled => + Sync[F].delay { val _ = wsCompletedPromise.cancel(true) } + } + .append(fs2.Stream(frameToNettyFrame(WebSocketFrame.close))) + + // Trigger listening for WS frames in the underlying fs2 StreamSubscribber + subscriber.sub.onSubscribe(subscription) + // Signal that a Publisher is ready to send result frames + publisher.success(StreamUnicastPublisher(stream, dispatcher)) + } + + override def onNext(t: NettyWebSocketFrame): Unit = { + subscriber.sub.onNext(t) + } + + override def onError(t: Throwable): Unit = { + subscriber.sub.onError(t) + if (!wsCompletedPromise.isDone()) { + val _ = wsCompletedPromise.setFailure(t) + } + } + + override def onComplete(): Unit = { + subscriber.sub.onComplete() + if (!wsCompletedPromise.isDone()) { + val _ = wsCompletedPromise.setSuccess() + } + } + + override def subscribe(s: Subscriber[_ >: NettyWebSocketFrame]): Unit = { + // A subscriber may come to read from our internal Publisher. It has to wait for the Publisher to be initialized. + publisher.future.onComplete { + case Success(p) => + p.subscribe(s) + case _ => // Never happens, we call succecss() explicitly + }(ExecutionContexts.sameThread) + + } + + private def optionallyConcatenateFrames(s: Stream[F, WebSocketFrame], doConcatenate: Boolean): Stream[F, WebSocketFrame] = + if (doConcatenate) { + type Accumulator = Option[Either[Array[Byte], String]] + + s.mapAccumulate(None: Accumulator) { + case (None, f: WebSocketFrame.Ping) => (None, Some(f)) + case (None, f: WebSocketFrame.Pong) => (None, Some(f)) + case (None, f: WebSocketFrame.Close) => (None, Some(f)) + case (None, f: WebSocketFrame.Data[_]) if f.finalFragment => (None, Some(f)) + case (Some(Left(acc)), f: WebSocketFrame.Binary) if f.finalFragment => (None, Some(f.copy(payload = acc ++ f.payload))) + case (Some(Left(acc)), f: WebSocketFrame.Binary) if !f.finalFragment => (Some(Left(acc ++ f.payload)), None) + case (Some(Right(acc)), f: WebSocketFrame.Text) if f.finalFragment => (None, Some(f.copy(payload = acc + f.payload))) + case (Some(Right(acc)), f: WebSocketFrame.Text) if !f.finalFragment => (Some(Right(acc + f.payload)), None) + case (acc, f) => throw new IllegalStateException(s"Cannot accumulate web socket frames. Accumulator: $acc, frame: $f.") + }.collect { case (_, Some(f)) => f } + } else s +} + +/** A special wrapper used to override internal logic of fs2, which calls cancel() silently when internal stream failures happen, causing + * the subscription to close the channel and stop the subscriber in such a way that errors can't get handled properly. With this wrapper we + * intentionally don't do anything on cancel(), so that the stream continues to fail properly on errors. We are handling cancelation + * manually with a channel promise passed to the processor logic. + * @param delegate + * a channel subscription which we don't want to notify about cancelation. + */ +class NonCancelingSubscription(delegate: Subscription) extends Subscription { + override def cancel(): Unit = () + override def request(n: Long): Unit = { + delegate.request(n) + } +} diff --git a/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsServerTest.scala b/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsServerTest.scala index e7a1280719..73053ed362 100644 --- a/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsServerTest.scala +++ b/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsServerTest.scala @@ -40,7 +40,17 @@ class NettyCatsServerTest extends TestSuite with EitherValues { new ServerStreamingTests(createServerTest).tests(Fs2Streams[IO])(drainFs2) ++ new ServerCancellationTests(createServerTest)(m, IO.asyncForIO).tests() ++ new NettyFs2StreamingCancellationTest(createServerTest).tests() ++ - new ServerGracefulShutdownTests(createServerTest, ioSleeper).tests() + new ServerGracefulShutdownTests(createServerTest, ioSleeper).tests() ++ + new ServerWebSocketTests( + createServerTest, + Fs2Streams[IO], + autoPing = true, + failingPipe = true, + handlePong = true + ) { + override def functionToPipe[A, B](f: A => B): streams.Pipe[A, B] = in => in.map(f) + override def emptyPipe[A, B]: fs2.Pipe[IO, A, B] = _ => fs2.Stream.empty + }.tests() IO.pure((tests, eventLoopGroup)) } { case (_, eventLoopGroup) => diff --git a/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsTestServerInterpreter.scala b/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsTestServerInterpreter.scala index 61be7e6f4d..9334504a2f 100644 --- a/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsTestServerInterpreter.scala +++ b/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsTestServerInterpreter.scala @@ -4,6 +4,7 @@ import cats.data.NonEmptyList import cats.effect.std.Dispatcher import cats.effect.{IO, Resource} import io.netty.channel.nio.NioEventLoopGroup +import sttp.capabilities.WebSockets import sttp.tapir.server.ServerEndpoint import sttp.tapir.server.netty.{NettyConfig, Route} import sttp.tapir.server.tests.TestServerInterpreter @@ -12,8 +13,8 @@ import sttp.capabilities.fs2.Fs2Streams import scala.concurrent.duration.FiniteDuration class NettyCatsTestServerInterpreter(eventLoopGroup: NioEventLoopGroup, dispatcher: Dispatcher[IO]) - extends TestServerInterpreter[IO, Fs2Streams[IO], NettyCatsServerOptions[IO], Route[IO]] { - override def route(es: List[ServerEndpoint[Fs2Streams[IO], IO]], interceptors: Interceptors): Route[IO] = { + extends TestServerInterpreter[IO, Fs2Streams[IO] with WebSockets, NettyCatsServerOptions[IO], Route[IO]] { + override def route(es: List[ServerEndpoint[Fs2Streams[IO] with WebSockets, IO]], interceptors: Interceptors): Route[IO] = { val serverOptions: NettyCatsServerOptions[IO] = interceptors( NettyCatsServerOptions.customiseInterceptors[IO](dispatcher) ).options diff --git a/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/NettyIdServer.scala b/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/NettyIdServer.scala index 46918fc8fd..85fc5bdd26 100644 --- a/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/NettyIdServer.scala +++ b/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/NettyIdServer.scala @@ -98,7 +98,8 @@ case class NettyIdServer(routes: Vector[IdRoute], options: NettyIdServerOptions, unsafeRunF, channelGroup, isShuttingDown, - config.serverHeader + config.serverHeader, + config.isSsl ), eventLoopGroup, socketOverride diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyConfig.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyConfig.scala index 07e43954c5..d098a6874c 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyConfig.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyConfig.scala @@ -10,6 +10,7 @@ import io.netty.handler.logging.LoggingHandler import io.netty.handler.ssl.SslContext import org.playframework.netty.http.HttpStreamsServerHandler import sttp.tapir.server.netty.NettyConfig.EventLoopConfig +import sttp.tapir.server.netty.internal._ import scala.concurrent.duration._ @@ -102,9 +103,12 @@ case class NettyConfig( def noGracefulShutdown = copy(gracefulShutdownTimeout = None) def serverHeader(h: String): NettyConfig = copy(serverHeader = Some(h)) + + def isSsl: Boolean = sslContext.isDefined } object NettyConfig { + def default: NettyConfig = NettyConfig( host = "localhost", port = 8080, @@ -126,7 +130,7 @@ object NettyConfig { def defaultInitPipeline(cfg: NettyConfig)(pipeline: ChannelPipeline, handler: ChannelHandler): Unit = { cfg.sslContext.foreach(s => pipeline.addLast(s.newHandler(pipeline.channel().alloc()))) - pipeline.addLast(new HttpServerCodec()) + pipeline.addLast(ServerCodecHandlerName, new HttpServerCodec()) pipeline.addLast(new HttpStreamsServerHandler()) pipeline.addLast(handler) if (cfg.addLoggingHandler) pipeline.addLast(new LoggingHandler()) diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyFutureServer.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyFutureServer.scala index f510e8e85b..0d4285f4e5 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyFutureServer.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyFutureServer.scala @@ -71,7 +71,7 @@ case class NettyFutureServer(routes: Vector[FutureRoute], options: NettyFutureSe val channelFuture = NettyBootstrap( config, - new NettyServerHandler(route, unsafeRunAsync, channelGroup, isShuttingDown, config.serverHeader), + new NettyServerHandler(route, unsafeRunAsync, channelGroup, isShuttingDown, config.serverHeader, config.isSsl), eventLoopGroup, socketOverride ) diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyResponseContent.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyResponseContent.scala index b17335f60b..77882e8dfd 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyResponseContent.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyResponseContent.scala @@ -3,8 +3,12 @@ package sttp.tapir.server.netty import io.netty.buffer.ByteBuf import io.netty.channel.ChannelPromise import io.netty.handler.codec.http.HttpContent +import io.netty.handler.codec.http.websocketx.WebSocketFrame import io.netty.handler.stream.{ChunkedFile, ChunkedStream} -import org.reactivestreams.Publisher +import org.reactivestreams.{Processor, Publisher} +import sttp.ws.{WebSocketFrame => SttpWebSocketFrame} + +import scala.concurrent.duration.FiniteDuration sealed trait NettyResponseContent { def channelPromise: ChannelPromise @@ -17,4 +21,12 @@ object NettyResponseContent { final case class ChunkedFileNettyResponseContent(channelPromise: ChannelPromise, chunkedFile: ChunkedFile) extends NettyResponseContent final case class ReactivePublisherNettyResponseContent(channelPromise: ChannelPromise, publisher: Publisher[HttpContent]) extends NettyResponseContent + final case class ReactiveWebSocketProcessorNettyResponseContent( + channelPromise: ChannelPromise, + processor: Processor[WebSocketFrame, WebSocketFrame], + ignorePong: Boolean, + autoPongOnPing: Boolean, + decodeCloseRequests: Boolean, + autoPing: Option[(FiniteDuration, SttpWebSocketFrame.Ping)] + ) extends NettyResponseContent } diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyBootstrap.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyBootstrap.scala index 71d57519a8..07edb5ace2 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyBootstrap.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyBootstrap.scala @@ -1,7 +1,7 @@ package sttp.tapir.server.netty.internal import io.netty.bootstrap.ServerBootstrap -import io.netty.channel.{Channel, ChannelFuture, ChannelInitializer, ChannelOption, EventLoopGroup} +import io.netty.channel.{Channel, ChannelFuture, ChannelHandler, ChannelInitializer, ChannelOption, EventLoopGroup} import io.netty.handler.timeout.ReadTimeoutHandler import sttp.tapir.server.netty.NettyConfig @@ -9,6 +9,8 @@ import java.net.{InetSocketAddress, SocketAddress} object NettyBootstrap { + private val ReadTimeoutHandlerName = "readTimeoutHandler" + def apply[F[_]]( nettyConfig: NettyConfig, handler: => NettyServerHandler[F], @@ -27,7 +29,10 @@ object NettyBootstrap { nettyConfig.requestTimeout match { case Some(requestTimeout) => - nettyConfigBuilder(ch.pipeline().addLast(new ReadTimeoutHandler(requestTimeout.toSeconds.toInt)), handler) + nettyConfigBuilder( + ch.pipeline().addLast(ReadTimeoutHandlerName, new ReadTimeoutHandler(requestTimeout.toSeconds.toInt)), + handler, + ) case None => nettyConfigBuilder(ch.pipeline(), handler) } diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala index c3333a7ff4..8ce0004b9a 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala @@ -4,8 +4,10 @@ import io.netty.buffer.{ByteBuf, Unpooled} import io.netty.channel._ import io.netty.channel.group.ChannelGroup import io.netty.handler.codec.http._ +import io.netty.handler.codec.http.websocketx.WebSocketServerHandshakerFactory import io.netty.handler.stream.{ChunkedFile, ChunkedStream} -import org.playframework.netty.http.{DefaultStreamedHttpResponse, StreamedHttpRequest} +import io.netty.handler.timeout.ReadTimeoutHandler +import org.playframework.netty.http.{DefaultStreamedHttpResponse, DefaultWebSocketHttpResponse, StreamedHttpRequest} import org.reactivestreams.Publisher import org.slf4j.LoggerFactory import sttp.monad.MonadError @@ -15,8 +17,10 @@ import sttp.tapir.server.netty.NettyResponseContent.{ ByteBufNettyResponseContent, ChunkedFileNettyResponseContent, ChunkedStreamNettyResponseContent, - ReactivePublisherNettyResponseContent + ReactivePublisherNettyResponseContent, + ReactiveWebSocketProcessorNettyResponseContent } +import sttp.tapir.server.netty.internal.ws.{NettyControlFrameHandler, WebSocketAutoPingHandler} import sttp.tapir.server.netty.{NettyResponse, NettyServerRequest, Route} import java.util.concurrent.atomic.AtomicBoolean @@ -36,7 +40,8 @@ class NettyServerHandler[F[_]]( unsafeRunAsync: (() => F[ServerResponse[NettyResponse]]) => (Future[ServerResponse[NettyResponse]], () => Future[Unit]), channelGroup: ChannelGroup, isShuttingDown: AtomicBoolean, - serverHeader: Option[String] + serverHeader: Option[String], + isSsl: Boolean = false )(implicit me: MonadError[F] ) extends SimpleChannelInboundHandler[HttpRequest] { @@ -62,6 +67,7 @@ class NettyServerHandler[F[_]]( private[this] val pendingResponses = MutableQueue.empty[() => Future[Unit]] private val logger = LoggerFactory.getLogger(getClass.getName) + private final val WebSocketAutoPingHandlerName = "wsAutoPingHandler" override def handlerAdded(ctx: ChannelHandlerContext): Unit = if (ctx.channel.isActive) { @@ -212,6 +218,17 @@ class NettyServerHandler[F[_]]( ctx.writeAndFlush(res, channelPromise).closeIfNeeded(req) }, + wsHandler = (responseContent) => { + if (isWsHandshake(req)) + initWsPipeline(ctx, responseContent, req) + else { + val buf = Unpooled.wrappedBuffer("Incorrect Web Socket handshake".getBytes) + val res = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.BAD_REQUEST, buf) + res.headers().set(HttpHeaderNames.CONTENT_LENGTH, buf.readableBytes()) + res.handleCloseAndKeepAliveHeaders(req) + ctx.writeAndFlush(res).closeIfNeeded(req) + } + }, noBodyHandler = () => { val res = new DefaultFullHttpResponse( req.protocolVersion(), @@ -227,13 +244,64 @@ class NettyServerHandler[F[_]]( } ) - private implicit class RichServerNettyResponse(val r: ServerResponse[NettyResponse]) { + private def initWsPipeline( + ctx: ChannelHandlerContext, + r: ReactiveWebSocketProcessorNettyResponseContent, + handshakeReq: HttpRequest + ) = { + ctx.pipeline().remove(this) + ctx.pipeline().remove(classOf[ReadTimeoutHandler]) + ctx + .pipeline() + .addAfter( + ServerCodecHandlerName, + WebSocketControlFrameHandlerName, + new NettyControlFrameHandler( + ignorePong = r.ignorePong, + autoPongOnPing = r.autoPongOnPing, + decodeCloseRequests = r.decodeCloseRequests + ) + ) + r.autoPing.foreach { case (interval, pingMsg) => + ctx + .pipeline() + .addAfter( + WebSocketControlFrameHandlerName, + WebSocketAutoPingHandlerName, + new WebSocketAutoPingHandler(interval, pingMsg) + ) + } + // Manually completing the promise, for some reason it won't be completed in writeAndFlush. We need its completion for NettyBodyListener to call back properly + r.channelPromise.setSuccess() + val _ = ctx.writeAndFlush( + // Push a special message down the pipeline, it will be handled by HttpStreamsServerHandler + // and from now on that handler will take control of the flow (our NettyServerHandler will not receive messages) + new DefaultWebSocketHttpResponse( + handshakeReq.protocolVersion(), + HttpResponseStatus.valueOf(200), + r.processor, // the Processor (Pipe) created by Tapir interpreter will be used by HttpStreamsServerHandler + new WebSocketServerHandshakerFactory(wsUrl(handshakeReq), null, false) + ) + ) + } + + private def isWsHandshake(req: HttpRequest): Boolean = + "Websocket".equalsIgnoreCase(req.headers().get(HttpHeaderNames.UPGRADE)) && + "Upgrade".equalsIgnoreCase(req.headers().get(HttpHeaderNames.CONNECTION)) + + // Only ancient WS protocol versions will use this in the response header. + private def wsUrl(req: HttpRequest): String = { + val scheme = if (isSsl) "wss" else "ws" + s"$scheme://${req.headers().get(HttpHeaderNames.HOST)}${req.uri()}" + } + private implicit class RichServerNettyResponse(r: ServerResponse[NettyResponse]) { def handle( ctx: ChannelHandlerContext, byteBufHandler: (ChannelPromise, ByteBuf) => Unit, chunkedStreamHandler: (ChannelPromise, ChunkedStream) => Unit, chunkedFileHandler: (ChannelPromise, ChunkedFile) => Unit, reactiveStreamHandler: (ChannelPromise, Publisher[HttpContent]) => Unit, + wsHandler: ReactiveWebSocketProcessorNettyResponseContent => Unit, noBodyHandler: () => Unit ): Unit = { r.body match { @@ -241,10 +309,11 @@ class NettyServerHandler[F[_]]( val values = function(ctx) values match { - case r: ByteBufNettyResponseContent => byteBufHandler(r.channelPromise, r.byteBuf) - case r: ChunkedStreamNettyResponseContent => chunkedStreamHandler(r.channelPromise, r.chunkedStream) - case r: ChunkedFileNettyResponseContent => chunkedFileHandler(r.channelPromise, r.chunkedFile) - case r: ReactivePublisherNettyResponseContent => reactiveStreamHandler(r.channelPromise, r.publisher) + case r: ByteBufNettyResponseContent => byteBufHandler(r.channelPromise, r.byteBuf) + case r: ChunkedStreamNettyResponseContent => chunkedStreamHandler(r.channelPromise, r.chunkedStream) + case r: ChunkedFileNettyResponseContent => chunkedFileHandler(r.channelPromise, r.chunkedFile) + case r: ReactivePublisherNettyResponseContent => reactiveStreamHandler(r.channelPromise, r.publisher) + case r: ReactiveWebSocketProcessorNettyResponseContent => wsHandler(r) } } case None => noBodyHandler() @@ -252,7 +321,7 @@ class NettyServerHandler[F[_]]( } } - private implicit class RichHttpMessage(val m: HttpMessage) { + private implicit class RichHttpMessage(m: HttpMessage) { def setHeadersFrom(response: ServerResponse[_]): Unit = { serverHeader.foreach(m.headers().set(HttpHeaderNames.SERVER, _)) response.headers @@ -278,7 +347,7 @@ class NettyServerHandler[F[_]]( } } - private implicit class RichChannelFuture(val cf: ChannelFuture) { + private implicit class RichChannelFuture(cf: ChannelFuture) { def closeIfNeeded(request: HttpRequest): Unit = { if (!HttpUtil.isKeepAlive(request) || isShuttingDown.get()) { cf.addListener(ChannelFutureListener.CLOSE) diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyToStreamsResponseBody.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyToStreamsResponseBody.scala index 0f335a7b14..507eee8a70 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyToStreamsResponseBody.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyToStreamsResponseBody.scala @@ -6,18 +6,23 @@ import sttp.capabilities.Streams import sttp.model.HasHeaders import sttp.tapir.server.interpreter.ToResponseBody import sttp.tapir.server.netty.NettyResponse +import sttp.tapir.server.netty.NettyResponseContent.{ + ByteBufNettyResponseContent, + ReactivePublisherNettyResponseContent, + ReactiveWebSocketProcessorNettyResponseContent +} import sttp.tapir.server.netty.internal.NettyToResponseBody._ -import sttp.tapir.server.netty.NettyResponseContent.{ByteBufNettyResponseContent, ReactivePublisherNettyResponseContent} import sttp.tapir.{CodecFormat, RawBodyType, WebSocketBodyOutput} import java.nio.ByteBuffer import java.nio.charset.Charset -/** Common logic for producing response body in all Netty backends that support streaming. These backends use streaming libraries - * like fs2 or zio-streams to obtain reactive Publishers representing responses like InputStreamBody, InputStreamRangeBody or FileBody. - * Other kinds of raw responses like directly available String, ByteArray or ByteBuffer can be returned without wrapping into a Publisher. +/** Common logic for producing response body in all Netty backends that support streaming. These backends use streaming libraries like fs2 + * or zio-streams to obtain reactive Publishers representing responses like InputStreamBody, InputStreamRangeBody or FileBody. Other kinds + * of raw responses like directly available String, ByteArray or ByteBuffer can be returned without wrapping into a Publisher. */ -private[netty] class NettyToStreamsResponseBody[S <: Streams[S]](streamCompatible: StreamCompatible[S]) extends ToResponseBody[NettyResponse, S] { +private[netty] class NettyToStreamsResponseBody[S <: Streams[S]](streamCompatible: StreamCompatible[S]) + extends ToResponseBody[NettyResponse, S] { override val streams: S = streamCompatible.streams @@ -37,7 +42,10 @@ private[netty] class NettyToStreamsResponseBody[S <: Streams[S]](streamCompatibl case RawBodyType.InputStreamBody => (ctx: ChannelHandlerContext) => - new ReactivePublisherNettyResponseContent(ctx.newPromise(), streamCompatible.publisherFromInputStream(() => v, DefaultChunkSize, length = None)) + new ReactivePublisherNettyResponseContent( + ctx.newPromise(), + streamCompatible.publisherFromInputStream(() => v, DefaultChunkSize, length = None) + ) case RawBodyType.InputStreamRangeBody => (ctx: ChannelHandlerContext) => @@ -47,7 +55,8 @@ private[netty] class NettyToStreamsResponseBody[S <: Streams[S]](streamCompatibl ) case RawBodyType.FileBody => - (ctx: ChannelHandlerContext) => new ReactivePublisherNettyResponseContent(ctx.newPromise(), streamCompatible.publisherFromFile(v, DefaultChunkSize)) + (ctx: ChannelHandlerContext) => + new ReactivePublisherNettyResponseContent(ctx.newPromise(), streamCompatible.publisherFromFile(v, DefaultChunkSize)) case _: RawBodyType.MultipartBody => throw new UnsupportedOperationException } @@ -68,5 +77,18 @@ private[netty] class NettyToStreamsResponseBody[S <: Streams[S]](streamCompatibl override def fromWebSocketPipe[REQ, RESP]( pipe: streams.Pipe[REQ, RESP], o: WebSocketBodyOutput[streams.Pipe[REQ, RESP], REQ, RESP, _, S] - ): NettyResponse = throw new UnsupportedOperationException + ): NettyResponse = (ctx: ChannelHandlerContext) => { + new ReactiveWebSocketProcessorNettyResponseContent( + ctx.newPromise(), + streamCompatible.asWsProcessor( + pipe.asInstanceOf[streamCompatible.streams.Pipe[REQ, RESP]], + o.asInstanceOf[WebSocketBodyOutput[streamCompatible.streams.Pipe[REQ, RESP], REQ, RESP, _, S]], + ctx + ), + ignorePong = o.ignorePong, + autoPongOnPing = o.autoPongOnPing, + decodeCloseRequests = o.decodeCloseRequests, + autoPing = o.autoPing + ) + } } diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/StreamCompatible.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/StreamCompatible.scala index 6d5da177bd..bbcdb4ea3e 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/StreamCompatible.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/StreamCompatible.scala @@ -1,16 +1,17 @@ package sttp.tapir.server.netty.internal +import io.netty.channel.ChannelHandlerContext import io.netty.handler.codec.http.HttpContent -import org.reactivestreams.Publisher +import io.netty.handler.codec.http.websocketx.WebSocketFrame +import org.reactivestreams.{Processor, Publisher} import sttp.capabilities.Streams -import sttp.tapir.FileRange +import sttp.tapir.{FileRange, WebSocketBodyOutput} import java.io.InputStream -/** - * Operations on streams that have to be implemented for each streaming integration (fs2, zio-streams, etc) used by Netty backends. - * This includes conversions like building a stream from a `File`, an `InputStream`, or a reactive `Publisher`. - * We also need implementation of a failed (errored) stream, as well as an empty stream (for handling empty requests). +/** Operations on streams that have to be implemented for each streaming integration (fs2, zio-streams, etc) used by Netty backends. This + * includes conversions like building a stream from a `File`, an `InputStream`, or a reactive `Publisher`. We also need implementation of a + * failed (errored) stream, as well as an empty stream (for handling empty requests). */ private[netty] trait StreamCompatible[S <: Streams[S]] { val streams: S @@ -27,4 +28,10 @@ private[netty] trait StreamCompatible[S <: Streams[S]] { def publisherFromInputStream(is: () => InputStream, chunkSize: Int, length: Option[Long]): Publisher[HttpContent] = asPublisher(fromInputStream(is, chunkSize, length)) + + def asWsProcessor[REQ, RESP]( + pipe: streams.Pipe[REQ, RESP], + o: WebSocketBodyOutput[streams.Pipe[REQ, RESP], REQ, RESP, _, S], + ctx: ChannelHandlerContext + ): Processor[WebSocketFrame, WebSocketFrame] } diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/package.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/package.scala index 12ac1afd26..4a4659d14e 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/package.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/package.scala @@ -1,13 +1,34 @@ package sttp.tapir.server.netty -import io.netty.handler.codec.http.HttpHeaders +import io.netty.channel.{ChannelFuture, ChannelFutureListener} +import io.netty.handler.codec.http.{HttpHeaderNames, HttpHeaders, HttpMessage} import sttp.model.Header +import sttp.tapir.server.model.ServerResponse import scala.collection.JavaConverters._ package object internal { - implicit class RichNettyHttpHeaders(underlying: HttpHeaders) { + implicit class RichNettyHttpHeaders(private val underlying: HttpHeaders) extends AnyVal { def toHeaderSeq: List[Header] = underlying.asScala.map(e => Header(e.getKey, e.getValue)).toList } + + implicit class RichChannelFuture(val cf: ChannelFuture) { + def close(): Unit = { + val _ = cf.addListener(ChannelFutureListener.CLOSE) + } + } + + implicit class RichHttpMessage(private val m: HttpMessage) extends AnyVal { + def setHeadersFrom(response: ServerResponse[_], serverHeader: Option[String]): Unit = { + serverHeader.foreach(m.headers().set(HttpHeaderNames.SERVER, _)) + response.headers + .groupBy(_.name) + .foreach { case (k, v) => + m.headers().set(k, v.map(_.value).asJava) + } + } + } + final val ServerCodecHandlerName = "serverCodecHandler" + final val WebSocketControlFrameHandlerName = "wsControlFrameHandler" } diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/ws/WebSocketAutoPingHandler.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/ws/WebSocketAutoPingHandler.scala new file mode 100644 index 0000000000..49838d1f59 --- /dev/null +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/ws/WebSocketAutoPingHandler.scala @@ -0,0 +1,46 @@ +package sttp.tapir.server.netty.internal.ws + +import io.netty.buffer.Unpooled +import io.netty.channel.{ChannelHandlerContext, ChannelInboundHandlerAdapter} +import io.netty.handler.codec.http.websocketx._ +import org.slf4j.LoggerFactory + +import java.util.concurrent.{ScheduledFuture, TimeUnit} +import scala.concurrent.duration.FiniteDuration + +/** If auto ping is enabled for an endpoint, this handler will be plugged into the pipeline. Its responsibility is to manage start and stop + * of the ping scheduler. + * @param pingInterval + * time interval to be used between sending pings to the client. + * @param frame + * desired content of the Ping frame, as specified in the Tapir endpoint output. + */ +class WebSocketAutoPingHandler(pingInterval: FiniteDuration, frame: sttp.ws.WebSocketFrame.Ping) extends ChannelInboundHandlerAdapter { + val nettyFrame = new PingWebSocketFrame(Unpooled.copiedBuffer(frame.payload)) + private var pingTask: ScheduledFuture[_] = _ + + private val logger = LoggerFactory.getLogger(getClass.getName) + + override def handlerAdded(ctx: ChannelHandlerContext): Unit = { + super.handlerAdded(ctx) + if (ctx.channel.isActive) { + logger.debug(s"STARTING WebSocket Ping scheduler for channel ${ctx.channel}, interval = $pingInterval") + val sendPing: Runnable = new Runnable { + override def run(): Unit = { + logger.trace(s"Sending PING WebSocket frame for channel ${ctx.channel}") + val _ = ctx.writeAndFlush(nettyFrame.retain()) + } + } + pingTask = + ctx.channel().eventLoop().scheduleAtFixedRate(sendPing, pingInterval.toMillis, pingInterval.toMillis, TimeUnit.MILLISECONDS) + } + } + + override def channelInactive(ctx: ChannelHandlerContext): Unit = { + super.channelInactive(ctx) + logger.debug(s"STOPPING WebSocket Ping scheduler for channel ${ctx.channel}") + if (pingTask != null) { + val _ = pingTask.cancel(false) + } + } +} diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/ws/WebSocketControlFrameHandler.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/ws/WebSocketControlFrameHandler.scala new file mode 100644 index 0000000000..04cfb931be --- /dev/null +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/ws/WebSocketControlFrameHandler.scala @@ -0,0 +1,40 @@ +package sttp.tapir.server.netty.internal.ws + +import io.netty.channel.{ChannelHandlerContext, ChannelInboundHandlerAdapter} +import io.netty.handler.codec.http.websocketx.{CloseWebSocketFrame, PingWebSocketFrame, PongWebSocketFrame} +import sttp.tapir.server.netty.internal._ + +/** Handles Ping, Pong, and Close frames for WebSockets. + */ +class NettyControlFrameHandler(ignorePong: Boolean, autoPongOnPing: Boolean, decodeCloseRequests: Boolean) + extends ChannelInboundHandlerAdapter { + + override def channelRead(ctx: ChannelHandlerContext, msg: Any): Unit = { + msg match { + case ping: PingWebSocketFrame => + if (autoPongOnPing) { + val _ = ctx.writeAndFlush(new PongWebSocketFrame(ping.content().retain())) + } else { + val _ = ping.content().release() + } + case pong: PongWebSocketFrame => + if (!ignorePong) { + val _ = ctx.fireChannelRead(pong) + } else { + val _ = pong.content().release() + } + case close: CloseWebSocketFrame => + if (decodeCloseRequests) { + // Passing the Close frame for further processing + val _ = ctx.fireChannelRead(close) + } else { + // Responding with Close immediately + val _ = ctx + .writeAndFlush(close) + .close() + } + case other => + val _ = ctx.fireChannelRead(other) + } + } +} diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/ws/WebSocketFrameConverters.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/ws/WebSocketFrameConverters.scala new file mode 100644 index 0000000000..d9e7c75cd1 --- /dev/null +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/ws/WebSocketFrameConverters.scala @@ -0,0 +1,38 @@ +package sttp.tapir.server.netty.internal.ws + +import io.netty.buffer.{ByteBuf, Unpooled} +import io.netty.handler.codec.http.websocketx._ +import io.netty.handler.codec.http.websocketx.{WebSocketFrame => NettyWebSocketFrame} +import sttp.ws.WebSocketFrame + +object WebSocketFrameConverters { + + def getBytes(buf: ByteBuf): Array[Byte] = { + val bytes = new Array[Byte](buf.readableBytes()) + buf.readBytes(bytes) + bytes + } + + def nettyFrameToFrame(nettyFrame: NettyWebSocketFrame): WebSocketFrame = { + nettyFrame match { + case text: TextWebSocketFrame => WebSocketFrame.Text(text.text, text.isFinalFragment, Some(text.rsv)) + case close: CloseWebSocketFrame => WebSocketFrame.Close(close.statusCode, close.reasonText) + case ping: PingWebSocketFrame => WebSocketFrame.Ping(getBytes(ping.content())) + case pong: PongWebSocketFrame => WebSocketFrame.Pong(getBytes(pong.content())) + case _ => WebSocketFrame.Binary(getBytes(nettyFrame.content()), nettyFrame.isFinalFragment, Some(nettyFrame.rsv)) + } + } + + def frameToNettyFrame(w: WebSocketFrame): NettyWebSocketFrame = w match { + case WebSocketFrame.Text(payload, finalFragment, rsvOpt) => + new TextWebSocketFrame(finalFragment, rsvOpt.getOrElse(0), payload) + case WebSocketFrame.Close(statusCode, reasonText) => + new CloseWebSocketFrame(statusCode, reasonText) + case WebSocketFrame.Ping(payload) => + new PingWebSocketFrame(Unpooled.wrappedBuffer(payload)) + case WebSocketFrame.Pong(payload) => + new PongWebSocketFrame(Unpooled.wrappedBuffer(payload)) + case WebSocketFrame.Binary(payload, finalFragment, rsvOpt) => + new BinaryWebSocketFrame(finalFragment, rsvOpt.getOrElse(0), Unpooled.wrappedBuffer(payload)) + } +} diff --git a/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/NettyZioServer.scala b/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/NettyZioServer.scala index 8eb00c55bc..390899e21a 100644 --- a/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/NettyZioServer.scala +++ b/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/NettyZioServer.scala @@ -91,7 +91,8 @@ case class NettyZioServer[R](routes: Vector[RIO[R, Route[RIO[R, *]]]], options: unsafeRunAsync(runtime), channelGroup, isShuttingDown, - config.serverHeader + config.serverHeader, + config.isSsl ), eventLoopGroup, socketOverride diff --git a/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/internal/ZioStreamCompatible.scala b/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/internal/ZioStreamCompatible.scala index 60a5a2b067..bf0f99fe26 100644 --- a/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/internal/ZioStreamCompatible.scala +++ b/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/internal/ZioStreamCompatible.scala @@ -4,13 +4,15 @@ import _root_.zio._ import _root_.zio.interop.reactivestreams._ import _root_.zio.stream.{Stream, ZStream} import io.netty.buffer.Unpooled +import io.netty.handler.codec.http.websocketx.WebSocketFrame import io.netty.handler.codec.http.{DefaultHttpContent, HttpContent} -import org.reactivestreams.Publisher +import org.reactivestreams.{Processor, Publisher} import sttp.capabilities.zio.ZioStreams -import sttp.tapir.FileRange import sttp.tapir.server.netty.internal._ +import sttp.tapir.{FileRange, WebSocketBodyOutput} import java.io.InputStream +import io.netty.channel.ChannelHandlerContext private[zio] object ZioStreamCompatible { @@ -64,6 +66,13 @@ private[zio] object ZioStreamCompatible { override def emptyStream: streams.BinaryStream = ZStream.empty + + override def asWsProcessor[REQ, RESP]( + pipe: Stream[Throwable, REQ] => Stream[Throwable, RESP], + o: WebSocketBodyOutput[Stream[Throwable, REQ] => Stream[Throwable, RESP], REQ, RESP, ?, ZioStreams], + ctx: ChannelHandlerContext + ): Processor[WebSocketFrame, WebSocketFrame] = + throw new UnsupportedOperationException("TODO") } } } diff --git a/server/pekko-http-server/src/test/scala/sttp/tapir/server/pekkohttp/PekkoHttpServerTest.scala b/server/pekko-http-server/src/test/scala/sttp/tapir/server/pekkohttp/PekkoHttpServerTest.scala index 7f15eecf29..edf58cb4db 100644 --- a/server/pekko-http-server/src/test/scala/sttp/tapir/server/pekkohttp/PekkoHttpServerTest.scala +++ b/server/pekko-http-server/src/test/scala/sttp/tapir/server/pekkohttp/PekkoHttpServerTest.scala @@ -105,7 +105,13 @@ class PekkoHttpServerTest extends TestSuite with EitherValues { new AllServerTests(createServerTest, interpreter, backend).tests() ++ new ServerStreamingTests(createServerTest).tests(PekkoStreams)(drainPekko) ++ - new ServerWebSocketTests(createServerTest, PekkoStreams) { + new ServerWebSocketTests( + createServerTest, + PekkoStreams, + autoPing = false, + failingPipe = true, + handlePong = false + ) { override def functionToPipe[A, B](f: A => B): streams.Pipe[A, B] = Flow.fromFunction(f) override def emptyPipe[A, B]: Flow[A, B, Any] = Flow.fromSinkAndSource(Sink.ignore, Source.empty) }.tests() ++ diff --git a/server/play-server/src/test/scala/sttp/tapir/server/play/PlayServerTest.scala b/server/play-server/src/test/scala/sttp/tapir/server/play/PlayServerTest.scala index 3662a63fd3..7ef2dc432d 100644 --- a/server/play-server/src/test/scala/sttp/tapir/server/play/PlayServerTest.scala +++ b/server/play-server/src/test/scala/sttp/tapir/server/play/PlayServerTest.scala @@ -122,7 +122,13 @@ class PlayServerTest extends TestSuite { ).tests() ++ new ServerStreamingTests(createServerTest).tests(PekkoStreams)(drainPekko) ++ new PlayServerWithContextTest(backend).tests() ++ - new ServerWebSocketTests(createServerTest, PekkoStreams) { + new ServerWebSocketTests( + createServerTest, + PekkoStreams, + autoPing = false, + failingPipe = true, + handlePong = false + ) { override def functionToPipe[A, B](f: A => B): streams.Pipe[A, B] = Flow.fromFunction(f) override def emptyPipe[A, B]: Flow[A, B, Any] = Flow.fromSinkAndSource(Sink.ignore, Source.empty) }.tests() ++ diff --git a/server/play29-server/src/test/scala/sttp/tapir/server/play/PlayServerTest.scala b/server/play29-server/src/test/scala/sttp/tapir/server/play/PlayServerTest.scala index 096b077f44..621feb18d3 100644 --- a/server/play29-server/src/test/scala/sttp/tapir/server/play/PlayServerTest.scala +++ b/server/play29-server/src/test/scala/sttp/tapir/server/play/PlayServerTest.scala @@ -117,7 +117,13 @@ class PlayServerTest extends TestSuite { new AllServerTests(createServerTest, interpreter, backend, basic = false, multipart = false, options = false).tests() ++ new ServerStreamingTests(createServerTest).tests(AkkaStreams)(drainAkka) ++ new PlayServerWithContextTest(backend).tests() ++ - new ServerWebSocketTests(createServerTest, AkkaStreams) { + new ServerWebSocketTests( + createServerTest, + AkkaStreams, + autoPing = false, + failingPipe = true, + handlePong = false + ) { override def functionToPipe[A, B](f: A => B): streams.Pipe[A, B] = Flow.fromFunction(f) override def emptyPipe[A, B]: Flow[A, B, Any] = Flow.fromSinkAndSource(Sink.ignore, Source.empty) }.tests() ++ diff --git a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerWebSocketTests.scala b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerWebSocketTests.scala index 378fd86ddf..7ec6a4a2bc 100644 --- a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerWebSocketTests.scala +++ b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerWebSocketTests.scala @@ -3,13 +3,16 @@ package sttp.tapir.server.tests import cats.effect.IO import cats.syntax.all._ import io.circe.generic.auto._ +import org.scalatest.EitherValues import org.scalatest.matchers.should.Matchers._ import sttp.capabilities.{Streams, WebSockets} import sttp.client3._ +import sttp.model.StatusCode import sttp.monad.MonadError import sttp.tapir._ import sttp.tapir.generic.auto._ import sttp.tapir.json.circe._ +import sttp.tapir.model.UnsupportedWebSocketFrameException import sttp.tapir.server.interceptor.CustomiseInterceptors import sttp.tapir.server.interceptor.metrics.MetricsRequestInterceptor import sttp.tapir.server.tests.ServerMetricsTest._ @@ -17,12 +20,17 @@ import sttp.tapir.tests.Test import sttp.tapir.tests.data.Fruit import sttp.ws.{WebSocket, WebSocketFrame} +import scala.concurrent.duration._ + abstract class ServerWebSocketTests[F[_], S <: Streams[S], OPTIONS, ROUTE]( createServerTest: CreateServerTest[F, S with WebSockets, OPTIONS, ROUTE], - val streams: S + val streams: S, + autoPing: Boolean, + failingPipe: Boolean, + handlePong: Boolean )(implicit m: MonadError[F] -) { +) extends EitherValues { import createServerTest._ def functionToPipe[A, B](f: A => B): streams.Pipe[A, B] @@ -43,11 +51,28 @@ abstract class ServerWebSocketTests[F[_], S <: Streams[S], OPTIONS, ROUTE]( _ <- ws.sendText("test2") m1 <- ws.receiveText() m2 <- ws.receiveText() - } yield List(m1, m2) + _ <- ws.close() + m3 <- ws.eitherClose(ws.receiveText()) + } yield List(m1, m2, m3) + }) + .get(baseUri.scheme("ws")) + .send(backend) + .map(_.body shouldBe Right(List("echo: test1", "echo: test2", Left(WebSocketFrame.Close(1000, "normal closure"))))) + }, + testServer( + endpoint.in("elsewhere").out(stringBody), + "WS handshake to a non-existing endpoint" + )((_: Unit) => pureResult("hello".asRight[Unit])) { (backend, baseUri) => + basicRequest + .response(asWebSocket { (ws: WebSocket[IO]) => + for { + _ <- ws.sendText("test") + m1 <- ws.receiveText() + } yield m1 }) .get(baseUri.scheme("ws")) .send(backend) - .map(_.body shouldBe Right(List("echo: test1", "echo: test2"))) + .map(_.code shouldBe StatusCode.NotFound) }, { val reqCounter = newRequestCounter[F] @@ -118,6 +143,60 @@ abstract class ServerWebSocketTests[F[_], S <: Streams[S], OPTIONS, ROUTE]( ) ) }, + testServer( + endpoint.out( + webSocketBody[String, CodecFormat.TextPlain, String, CodecFormat.TextPlain](streams) + .autoPing(None) + .autoPongOnPing(true) + ), + "pong on ping" + )((_: Unit) => pureResult(stringEcho.asRight[Unit])) { (backend, baseUri) => + basicRequest + .response(asWebSocket { (ws: WebSocket[IO]) => + for { + _ <- ws.sendText("test1") + _ <- ws.send(WebSocketFrame.Ping("test-ping-text".getBytes())) + m1 <- ws.receive() + _ <- ws.sendText("test2") + m2 <- ws.receive() + } yield List(m1, m2) + }) + .get(baseUri.scheme("ws")) + .send(backend) + .map((r: Response[Either[String, List[WebSocketFrame]]]) => + assert( + r.body.value exists { + case WebSocketFrame.Pong(array) => array sameElements "test-ping-text".getBytes + case _ => false + }, + s"Missing Pong(test-ping-text) in ${r.body}" + ) + ) + }, + testServer( + endpoint.out( + webSocketBody[String, CodecFormat.TextPlain, String, CodecFormat.TextPlain](streams) + .autoPing(None) + .autoPongOnPing(false) + ), + "not pong on ping if disabled" + )((_: Unit) => pureResult(stringEcho.asRight[Unit])) { (backend, baseUri) => + basicRequest + .response(asWebSocket { (ws: WebSocket[IO]) => + for { + _ <- ws.sendText("test1") + _ <- ws.send(WebSocketFrame.Ping("test-ping-text".getBytes())) + m1 <- ws.receiveText() + _ <- ws.sendText("test2") + m2 <- ws.receiveText() + } yield List(m1, m2) + }) + .get(baseUri.scheme("ws")) + .send(backend) + .map( + _.body shouldBe Right(List("echo: test1", "echo: test2")) + ) + }, testServer( endpoint.out(webSocketBody[String, CodecFormat.TextPlain, String, CodecFormat.TextPlain](streams)), "empty client stream" @@ -130,18 +209,164 @@ abstract class ServerWebSocketTests[F[_], S <: Streams[S], OPTIONS, ROUTE]( }, testServer( endpoint - .in(isWebSocket) + .in(query[String]("securityToken")) .errorOut(stringBody) .out(stringWs), - "non web-socket request" - )(isWS => if (isWS) pureResult(stringEcho.asRight) else pureResult("Not a WS!".asLeft)) { (backend, baseUri) => - basicRequest - .response(asString) - .get(baseUri.scheme("http")) - .send(backend) - .map(_.body shouldBe Left("Not a WS!")) + "switch to WS after a normal HTTP request" + )(token => if (token == "correctToken") pureResult(stringEcho.asRight) else pureResult("Incorrect token!".asLeft)) { + (backend, baseUri) => + for { + response1 <- basicRequest + .response(asString) + .get(uri"$baseUri?securityToken=wrong".scheme("http")) + .send(backend) + response2 <- basicRequest + .response(asWebSocket { (ws: WebSocket[IO]) => + ws.sendText("testOk") >> ws.receiveText() + }) + .get(uri"$baseUri?securityToken=correctToken".scheme("ws")) + .send(backend) + } yield { + response1.body shouldBe Left("Incorrect token!") + response2.body shouldBe Right("echo: testOk") + } + }, + testServer( + endpoint + .in(query[String]("securityToken")) + .errorOut(stringBody) + .out(stringWs), + "reject WS handshake, then accept a corrected one" + )(token => if (token == "correctToken") pureResult(stringEcho.asRight) else pureResult("Incorrect token!".asLeft)) { + (backend, baseUri) => + for { + response1 <- basicRequest + .response(asWebSocket { (ws: WebSocket[IO]) => + ws.sendText("testWrong") >> ws.receiveText() + }) + .get(uri"$baseUri?securityToken=wrong".scheme("ws")) + .send(backend) + response2 <- basicRequest + .response(asWebSocket { (ws: WebSocket[IO]) => + ws.sendText("testOk") >> ws.receiveText() + }) + .get(uri"$baseUri?securityToken=correctToken".scheme("ws")) + .send(backend) + } yield { + response1.code shouldBe StatusCode.BadRequest + response2.body shouldBe Right("echo: testOk") + } } - ) + ) ++ autoPingTests ++ failingPipeTests ++ handlePongTests + + val autoPingTests = + if (autoPing) + List( + testServer( + endpoint.out( + webSocketBody[String, CodecFormat.TextPlain, String, CodecFormat.TextPlain](streams) + .autoPing(Some((50.millis, WebSocketFrame.ping))) + ), + "auto ping" + )((_: Unit) => pureResult(stringEcho.asRight[Unit])) { (backend, baseUri) => + basicRequest + .response(asWebSocket { (ws: WebSocket[IO]) => + for { + _ <- ws.sendText("test1") + _ <- IO.sleep(150.millis) + _ <- ws.sendText("test2") + m1 <- ws.receive() + m2 <- ws.receive() + _ <- ws.sendText("test3") + m3 <- ws.receive() + } yield List(m1, m2, m3) + }) + .get(baseUri.scheme("ws")) + .send(backend) + .map((r: Response[Either[String, List[WebSocketFrame]]]) => + assert(r.body.value.exists(_.isInstanceOf[WebSocketFrame.Ping]), s"Missing Ping frame in WS responses: $r") + ) + } + ) + else List.empty + + // Optional, because some backends don't handle exceptions in the pipe gracefully, they just swallow it silently and hang forever + val failingPipeTests = + if (failingPipe) + List( + testServer( + endpoint.out(webSocketBody[String, CodecFormat.TextPlain, String, CodecFormat.TextPlain](streams)), + "failing pipe" + )((_: Unit) => + pureResult(functionToPipe[String, String] { + case "error-trigger" => throw new Exception("Boom!") + case msg => s"echo: $msg" + }.asRight[Unit]) + ) { (backend, baseUri) => + basicRequest + .response(asWebSocket { (ws: WebSocket[IO]) => + for { + _ <- ws.sendText("test1") + _ <- ws.sendText("test2") + _ <- ws.sendText("error-trigger") + m1 <- ws.eitherClose(ws.receiveText()) + m2 <- ws.eitherClose(ws.receiveText()) + m3 <- ws.eitherClose(ws.receiveText()) + } yield List(m1, m2, m3) + }) + .get(baseUri.scheme("ws")) + .send(backend) + .map { r => + val results = r.body.map(_.map(_.left.map(_.statusCode))).value + results.take(2) shouldBe + List(Right("echo: test1"), Right("echo: test2")) + val closeCode = results.last.left.value + assert(closeCode == 1000 || closeCode == 1011) // some servers respond with Close(normal), some with Close(error) + } + } + ) + else List.empty + + val handlePongTests = + if (handlePong) + List( + testServer( + { + implicit def textOrPongWebSocketFrame[A, CF <: CodecFormat](implicit + stringCodec: Codec[String, A, CF] + ): Codec[WebSocketFrame, A, CF] = + Codec // A custom codec to handle Pongs + .id[WebSocketFrame, CF](stringCodec.format, Schema.string) + .mapDecode { + case WebSocketFrame.Text(p, _, _) => stringCodec.decode(p) + case WebSocketFrame.Pong(payload) => + stringCodec.decode(new String(payload)) + case f => DecodeResult.Error(f.toString, new UnsupportedWebSocketFrameException(f)) + }(a => WebSocketFrame.text(stringCodec.encode(a))) + .schema(stringCodec.schema) - // TODO: tests for ping/pong (control frames handling) + endpoint.out( + webSocketBody[String, CodecFormat.TextPlain, String, CodecFormat.TextPlain](streams) + .autoPing(None) + .ignorePong(false) + ) + }, + "not ignore pong" + )((_: Unit) => pureResult(stringEcho.asRight[Unit])) { (backend, baseUri) => + basicRequest + .response(asWebSocket { (ws: WebSocket[IO]) => + for { + _ <- ws.sendText("test1") + _ <- ws.send(WebSocketFrame.Pong("test-pong-text".getBytes())) + m1 <- ws.receiveText() + _ <- ws.sendText("test2") + m2 <- ws.receiveText() + } yield List(m1, m2) + }) + .get(baseUri.scheme("ws")) + .send(backend) + .map(_.body shouldBe Right(List("echo: test1", "echo: test-pong-text"))) + } + ) + else List.empty } diff --git a/server/vertx-server/cats/src/test/scala/sttp/tapir/server/vertx/cats/CatsVertxServerTest.scala b/server/vertx-server/cats/src/test/scala/sttp/tapir/server/vertx/cats/CatsVertxServerTest.scala index 9a96f67fd1..9dfd570fa7 100644 --- a/server/vertx-server/cats/src/test/scala/sttp/tapir/server/vertx/cats/CatsVertxServerTest.scala +++ b/server/vertx-server/cats/src/test/scala/sttp/tapir/server/vertx/cats/CatsVertxServerTest.scala @@ -37,7 +37,13 @@ class CatsVertxServerTest extends TestSuite { partOtherHeaderSupport = false ).tests() ++ new ServerStreamingTests(createServerTest).tests(Fs2Streams.apply[IO])(drainFs2) ++ - new ServerWebSocketTests(createServerTest, Fs2Streams.apply[IO]) { + new ServerWebSocketTests( + createServerTest, + Fs2Streams.apply[IO], + autoPing = false, + failingPipe = true, + handlePong = true + ) { override def functionToPipe[A, B](f: A => B): streams.Pipe[A, B] = in => in.map(f) override def emptyPipe[A, B]: streams.Pipe[A, B] = _ => Stream.empty }.tests() diff --git a/server/vertx-server/src/test/scala/sttp/tapir/server/vertx/VertxServerTest.scala b/server/vertx-server/src/test/scala/sttp/tapir/server/vertx/VertxServerTest.scala index 8a7a8f57ce..bb6294924a 100644 --- a/server/vertx-server/src/test/scala/sttp/tapir/server/vertx/VertxServerTest.scala +++ b/server/vertx-server/src/test/scala/sttp/tapir/server/vertx/VertxServerTest.scala @@ -27,7 +27,7 @@ class VertxServerTest extends TestSuite { def drainVertx[T](source: ReadStream[T]): Future[Unit] = { val p = Promise[Unit]() // Handler for stream data - do nothing with the data - val dataHandler: Handler[T] = (_: T) => () + val dataHandler: Handler[T] = (_: T) => () // End handler - complete the promise when the stream ends val endHandler: Handler[Void] = (_: Void) => p.success(()) @@ -53,7 +53,13 @@ class VertxServerTest extends TestSuite { partContentTypeHeaderSupport = true, partOtherHeaderSupport = false ).tests() ++ new ServerStreamingTests(createServerTest).tests(VertxStreams)(drainVertx[Buffer]) ++ - (new ServerWebSocketTests(createServerTest, VertxStreams) { + (new ServerWebSocketTests( + createServerTest, + VertxStreams, + autoPing = false, + failingPipe = false, + handlePong = true + ) { override def functionToPipe[A, B](f: A => B): VertxStreams.Pipe[A, B] = in => new ReadStreamMapping(in, f) override def emptyPipe[A, B]: VertxStreams.Pipe[A, B] = _ => new EmptyReadStream() }).tests() diff --git a/server/vertx-server/zio/src/test/scala/sttp/tapir/server/vertx/zio/ZioVertxServerTest.scala b/server/vertx-server/zio/src/test/scala/sttp/tapir/server/vertx/zio/ZioVertxServerTest.scala index 0a8c27f4e4..a8d79dcb5b 100644 --- a/server/vertx-server/zio/src/test/scala/sttp/tapir/server/vertx/zio/ZioVertxServerTest.scala +++ b/server/vertx-server/zio/src/test/scala/sttp/tapir/server/vertx/zio/ZioVertxServerTest.scala @@ -43,7 +43,13 @@ class ZioVertxServerTest extends TestSuite with OptionValues { partOtherHeaderSupport = false ).tests() ++ additionalTests() ++ new ServerStreamingTests(createServerTest).tests(ZioStreams)(drainZStream) ++ - new ServerWebSocketTests(createServerTest, ZioStreams) { + new ServerWebSocketTests( + createServerTest, + ZioStreams, + autoPing = true, + failingPipe = true, + handlePong = true + ) { override def functionToPipe[A, B](f: A => B): streams.Pipe[A, B] = in => in.map(f) override def emptyPipe[A, B]: streams.Pipe[A, B] = _ => ZStream.empty }.tests() diff --git a/server/zio-http-server/src/test/scala/sttp/tapir/server/ziohttp/ZioHttpServerTest.scala b/server/zio-http-server/src/test/scala/sttp/tapir/server/ziohttp/ZioHttpServerTest.scala index 828290d902..417e16a5af 100644 --- a/server/zio-http-server/src/test/scala/sttp/tapir/server/ziohttp/ZioHttpServerTest.scala +++ b/server/zio-http-server/src/test/scala/sttp/tapir/server/ziohttp/ZioHttpServerTest.scala @@ -266,7 +266,13 @@ class ZioHttpServerTest extends TestSuite { ).tests() ++ new ServerStreamingTests(createServerTest).tests(ZioStreams)(drainZStream) ++ new ZioHttpCompositionTest(createServerTest).tests() ++ - new ServerWebSocketTests(createServerTest, ZioStreams) { + new ServerWebSocketTests( + createServerTest, + ZioStreams, + autoPing = true, + failingPipe = false, + handlePong = false + ) { override def functionToPipe[A, B](f: A => B): ZioStreams.Pipe[A, B] = in => in.map(f) override def emptyPipe[A, B]: ZioStreams.Pipe[A, B] = _ => ZStream.empty }.tests() ++