From 28455d2587cc34eb0ce3edbee23e71f4fd1d9091 Mon Sep 17 00:00:00 2001 From: kciesielski Date: Wed, 20 Mar 2024 08:38:46 +0100 Subject: [PATCH 01/31] Initial implementation --- .../examples/WebSocketsNettyCatsServer.scala | 75 ++++++++ .../tapir/perf/netty/cats/NettyCats.scala | 26 ++- .../server/netty/cats/NettyCatsServer.scala | 5 +- .../cats/NettyCatsServerInterpreter.scala | 5 +- .../cats/internal/Fs2StreamCompatible.scala | 27 +-- .../cats/internal/NettyCatsRequestBody.scala | 4 +- .../internal/WebSocketPipeProcessor.scala | 94 ++++++++++ .../netty/cats/NettyCatsServerTest.scala | 6 +- .../cats/NettyCatsTestServerInterpreter.scala | 5 +- .../server/netty/loom/NettyIdServer.scala | 2 + .../sttp/tapir/server/netty/NettyConfig.scala | 13 +- .../server/netty/NettyFutureServer.scala | 2 + .../server/netty/NettyResponseContent.scala | 12 ++ .../netty/internal/NettyBootstrap.scala | 9 +- .../netty/internal/NettyServerHandler.scala | 11 ++ .../internal/NettyToStreamsResponseBody.scala | 32 +++- .../internal/ReactiveWebSocketHandler.scala | 161 ++++++++++++++++++ .../netty/internal/StreamCompatible.scala | 7 +- .../internal/WebSocketAutoPingHandler.scala | 45 +++++ .../WebSocketControlFrameHandler.scala | 33 ++++ .../internal/WebSocketFrameConverters.scala | 32 ++++ .../server/netty/zio/NettyZioServer.scala | 3 +- .../zio/internal/ZioStreamCompatible.scala | 11 +- .../server/tests/ServerWebSocketTests.scala | 6 +- 24 files changed, 582 insertions(+), 44 deletions(-) create mode 100644 examples/src/main/scala/sttp/tapir/examples/WebSocketsNettyCatsServer.scala create mode 100644 server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/WebSocketPipeProcessor.scala create mode 100644 server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/ReactiveWebSocketHandler.scala create mode 100644 server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/WebSocketAutoPingHandler.scala create mode 100644 server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/WebSocketControlFrameHandler.scala create mode 100644 server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/WebSocketFrameConverters.scala diff --git a/examples/src/main/scala/sttp/tapir/examples/WebSocketsNettyCatsServer.scala b/examples/src/main/scala/sttp/tapir/examples/WebSocketsNettyCatsServer.scala new file mode 100644 index 0000000000..50729f0c2e --- /dev/null +++ b/examples/src/main/scala/sttp/tapir/examples/WebSocketsNettyCatsServer.scala @@ -0,0 +1,75 @@ +package sttp.tapir.examples + +import cats.effect.{IO, IOApp} +import sttp.client3._ +import sttp.model.StatusCode +import sttp.tapir.server.netty.cats.NettyCatsServer +import sttp.tapir.* +import scala.concurrent.duration._ +import sttp.capabilities.fs2.Fs2Streams +import sttp.ws.WebSocket +import sttp.client3.pekkohttp.PekkoHttpBackend +import scala.concurrent.Future + +object WebSocketsNettyCatsServer extends IOApp.Simple { + // One endpoint on GET /hello with query parameter `name` + val helloWorldEndpoint: PublicEndpoint[String, Unit, String, Any] = + endpoint.get.in("hello").in(query[String]("name")).out(stringBody) + + val wsEndpoint = + endpoint.get.in("ws").out(webSocketBody[String, CodecFormat.TextPlain, String, CodecFormat.TextPlain](Fs2Streams[IO])) + + val wsServerEndpoint = wsEndpoint.serverLogicSuccess(_ => + IO.pure(in => in.evalMap(str => IO.println(s"responding with ${str.toUpperCase}") >> IO.pure(str.toUpperCase()))) + ) + // Just returning passed name with `Hello, ` prepended + val 
helloWorldServerEndpoint = helloWorldEndpoint + .serverLogic(name => IO.pure[Either[Unit, String]](Right(s"Hello, $name!"))) + + private val declaredPort = 9090 + private val declaredHost = "localhost" + + // Creating handler for netty bootstrap + override def run = NettyCatsServer + .io() + .use { server => + for { + binding <- server + .port(declaredPort) + .host(declaredHost) + .addEndpoints(List(wsServerEndpoint, helloWorldServerEndpoint)) + .start() + result <- IO + .fromFuture(IO.delay { + val port = binding.port + val host = binding.hostName + println(s"Server started at port = ${binding.port}") + import scala.concurrent.ExecutionContext.Implicits.global + def useWebSocket(ws: WebSocket[Future]): Future[Unit] = { + def send(i: Int) = ws.sendText(s"Hello $i!") + def receive() = ws.receiveText().map(t => println(s"Client RECEIVED: $t")) + for { + _ <- send(1) + _ <- receive() + _ <- send(2) + _ <- send(3) + _ <- receive() + } yield () + } + val backend = PekkoHttpBackend() + + val url = uri"ws://$host:$port/ws" + val allGood = uri"http://$host:$port/hello?name=Netty" + basicRequest.response(asStringAlways).get(allGood).send(backend).map(r => println(r.body)) + .flatMap { _ => + basicRequest + .response(asWebSocket(useWebSocket)) + .get(url) + .send(backend) + } + .andThen { case _ => backend.close() } + }) + .guarantee(binding.stop()) + } yield result + } +} diff --git a/perf-tests/src/main/scala/sttp/tapir/perf/netty/cats/NettyCats.scala b/perf-tests/src/main/scala/sttp/tapir/perf/netty/cats/NettyCats.scala index 85d3ac7c2b..10639eb7a9 100644 --- a/perf-tests/src/main/scala/sttp/tapir/perf/netty/cats/NettyCats.scala +++ b/perf-tests/src/main/scala/sttp/tapir/perf/netty/cats/NettyCats.scala @@ -3,16 +3,36 @@ package sttp.tapir.perf.netty.cats import cats.effect.IO import cats.effect.kernel.Resource import cats.effect.std.Dispatcher +import fs2.Stream +import sttp.tapir.{CodecFormat, webSocketBody} import sttp.tapir.perf.Common._ import sttp.tapir.perf.apis._ import sttp.tapir.server.ServerEndpoint import sttp.tapir.server.netty.cats.NettyCatsServer import sttp.tapir.server.netty.cats.NettyCatsServerOptions +import sttp.ws.WebSocketFrame +import sttp.capabilities.fs2.Fs2Streams -object Tapir extends Endpoints +import scala.concurrent.duration._ -object NettyCats { +object Tapir extends Endpoints { + val wsResponseStream = Stream.fixedRate[IO](WebSocketSingleResponseLag, dampen = false) + val wsEndpoint = wsBaseEndpoint + .out( + webSocketBody[Long, CodecFormat.TextPlain, Long, CodecFormat.TextPlain](Fs2Streams[IO]) + .concatenateFragmentedFrames(false) + .autoPongOnPing(false) + .ignorePong(true) + .autoPing(None) + ) +} +object NettyCats { + val wsServerEndpoint = Tapir.wsEndpoint.serverLogicSuccess(_ => + IO.pure { (in: Stream[IO, Long]) => + Tapir.wsResponseStream.evalMap(_ => IO.realTime.map(_.toMillis)).concurrently(in.as(())) + } + ) def runServer(endpoints: List[ServerEndpoint[Any, IO]], withServerLog: Boolean = false): IO[ServerRunner.KillSwitch] = { val declaredPort = Port val declaredHost = "0.0.0.0" @@ -25,7 +45,7 @@ object NettyCats { server .port(declaredPort) .host(declaredHost) - .addEndpoints(endpoints) + .addEndpoints(wsServerEndpoint :: endpoints) .start() )(binding => binding.stop()) } yield ()).allocated.map(_._2) diff --git a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServer.scala b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServer.scala index 1b7a449326..abe2b64be4 100644 --- 
a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServer.scala +++ b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServer.scala @@ -23,12 +23,14 @@ import java.util.UUID import java.util.concurrent.atomic.AtomicBoolean import scala.concurrent.Future import scala.concurrent.duration._ +import sttp.capabilities.WebSockets +import sttp.tapir.server.netty.internal.ReactiveWebSocketHandler case class NettyCatsServer[F[_]: Async](routes: Vector[Route[F]], options: NettyCatsServerOptions[F], config: NettyConfig) { def addEndpoint(se: ServerEndpoint[Fs2Streams[F], F]): NettyCatsServer[F] = addEndpoints(List(se)) def addEndpoint(se: ServerEndpoint[Fs2Streams[F], F], overrideOptions: NettyCatsServerOptions[F]): NettyCatsServer[F] = addEndpoints(List(se), overrideOptions) - def addEndpoints(ses: List[ServerEndpoint[Fs2Streams[F], F]]): NettyCatsServer[F] = addRoute( + def addEndpoints(ses: List[ServerEndpoint[Fs2Streams[F] with WebSockets, F]]): NettyCatsServer[F] = addRoute( NettyCatsServerInterpreter(options).toRoute(ses) ) def addEndpoints(ses: List[ServerEndpoint[Fs2Streams[F], F]], overrideOptions: NettyCatsServerOptions[F]): NettyCatsServer[F] = addRoute( @@ -75,6 +77,7 @@ case class NettyCatsServer[F[_]: Async](routes: Vector[Route[F]], options: Netty NettyBootstrap( config, new NettyServerHandler(route, unsafeRunAsync, channelGroup, isShuttingDown, config.serverHeader), + new ReactiveWebSocketHandler(route, channelGroup, unsafeRunAsync, config.sslContext.isDefined), eventLoopGroup, socketOverride ) diff --git a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServerInterpreter.scala b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServerInterpreter.scala index 9e1c4722c0..8f65e39a73 100644 --- a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServerInterpreter.scala +++ b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServerInterpreter.scala @@ -14,12 +14,13 @@ import sttp.tapir.server.interpreter.{BodyListener, FilterServerEndpoints, Serve import sttp.tapir.server.netty.internal.{NettyBodyListener, RunAsync, _} import sttp.tapir.server.netty.cats.internal.NettyCatsRequestBody import sttp.tapir.server.netty.{NettyResponse, NettyServerRequest, Route} +import sttp.capabilities.WebSockets trait NettyCatsServerInterpreter[F[_]] { implicit def async: Async[F] def nettyServerOptions: NettyCatsServerOptions[F] - def toRoute(ses: List[ServerEndpoint[Fs2Streams[F], F]]): Route[F] = { + def toRoute(ses: List[ServerEndpoint[Fs2Streams[F] with WebSockets, F]]): Route[F] = { implicit val monad: MonadError[F] = new CatsMonadError[F] val runAsync = new RunAsync[F] { @@ -31,7 +32,7 @@ trait NettyCatsServerInterpreter[F[_]] { val createFile = nettyServerOptions.createFile val deleteFile = nettyServerOptions.deleteFile - val serverInterpreter = new ServerInterpreter[Fs2Streams[F], F, NettyResponse, Fs2Streams[F]]( + val serverInterpreter = new ServerInterpreter[Fs2Streams[F] with WebSockets, F, NettyResponse, Fs2Streams[F]]( FilterServerEndpoints(ses), new NettyCatsRequestBody(createFile, Fs2StreamCompatible[F](nettyServerOptions.dispatcher)), new NettyToStreamsResponseBody(Fs2StreamCompatible[F](nettyServerOptions.dispatcher)), diff --git a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/Fs2StreamCompatible.scala 
b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/Fs2StreamCompatible.scala index 3849fb4a19..28167f4422 100644 --- a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/Fs2StreamCompatible.scala +++ b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/Fs2StreamCompatible.scala @@ -1,22 +1,19 @@ package sttp.tapir.server.netty.cats.internal +import cats.effect.kernel.{Async, Sync} +import cats.effect.std.Dispatcher +import fs2.interop.reactivestreams.{StreamSubscriber, StreamUnicastPublisher} +import fs2.io.file.{Files, Flags, Path} +import fs2.{Chunk, Pipe} import io.netty.buffer.Unpooled +import io.netty.handler.codec.http.websocketx.WebSocketFrame import io.netty.handler.codec.http.{DefaultHttpContent, HttpContent} -import org.reactivestreams.Publisher -import sttp.tapir.FileRange +import org.reactivestreams.{Processor, Publisher} +import sttp.capabilities.fs2.Fs2Streams import sttp.tapir.server.netty.internal._ +import sttp.tapir.{FileRange, WebSocketBodyOutput} import java.io.InputStream -import cats.effect.std.Dispatcher -import sttp.capabilities.fs2.Fs2Streams -import fs2.io.file.Path -import fs2.io.file.Files -import cats.effect.kernel.Async -import fs2.io.file.Flags -import fs2.interop.reactivestreams.StreamUnicastPublisher -import cats.effect.kernel.Sync -import fs2.Chunk -import fs2.interop.reactivestreams.StreamSubscriber object Fs2StreamCompatible { @@ -68,6 +65,12 @@ object Fs2StreamCompatible { override def emptyStream: streams.BinaryStream = fs2.Stream.empty + override def asWsProcessor[REQ, RESP]( + pipe: Pipe[F, REQ, RESP], + o: WebSocketBodyOutput[Pipe[F, REQ, RESP], REQ, RESP, ?, Fs2Streams[F]] + ): Processor[WebSocketFrame, WebSocketFrame] = + new WebSocketPipeProcessor[F, REQ, RESP](pipe, dispatcher, o) + private def inputStreamToFs2(inputStream: () => InputStream, chunkSize: Int) = fs2.io.readInputStream( Sync[F].blocking(inputStream()), diff --git a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/NettyCatsRequestBody.scala b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/NettyCatsRequestBody.scala index 4dfe6c22fb..98b0742b74 100644 --- a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/NettyCatsRequestBody.scala +++ b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/NettyCatsRequestBody.scala @@ -12,6 +12,7 @@ import sttp.tapir.TapirFile import sttp.tapir.integ.cats.effect.CatsMonadError import sttp.tapir.model.ServerRequest import sttp.tapir.server.netty.internal.{NettyStreamingRequestBody, StreamCompatible} +import sttp.capabilities.WebSockets private[cats] class NettyCatsRequestBody[F[_]: Async]( val createFile: ServerRequest => F[TapirFile], @@ -24,7 +25,8 @@ private[cats] class NettyCatsRequestBody[F[_]: Async]( streamCompatible.fromPublisher(publisher, maxBytes).compile.to(Chunk).map(_.toArray[Byte]) override def writeToFile(serverRequest: ServerRequest, file: TapirFile, maxBytes: Option[Long]): F[Unit] = - (toStream(serverRequest, maxBytes).asInstanceOf[streamCompatible.streams.BinaryStream]) + (toStream(serverRequest, maxBytes) + .asInstanceOf[streamCompatible.streams.BinaryStream]) .through( Files[F](Files.forAsync[F]).writeAll(Path.fromNioPath(file.toPath)) ) diff --git a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/WebSocketPipeProcessor.scala 
b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/WebSocketPipeProcessor.scala new file mode 100644 index 0000000000..518db27e65 --- /dev/null +++ b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/WebSocketPipeProcessor.scala @@ -0,0 +1,94 @@ +package sttp.tapir.server.netty.cats.internal + +import cats.Applicative +import cats.effect.kernel.Async +import cats.effect.std.Dispatcher +import fs2.interop.reactivestreams.{StreamSubscriber, StreamUnicastPublisher} +import fs2.{Pipe, Stream} +import io.netty.handler.codec.http.websocketx.{WebSocketFrame => NettyWebSocketFrame} +import org.reactivestreams.{Processor, Publisher, Subscriber, Subscription} +import sttp.capabilities.fs2.Fs2Streams +import sttp.tapir.model.WebSocketFrameDecodeFailure +import sttp.tapir.server.netty.internal.WebSocketFrameConverters._ +import sttp.tapir.{DecodeResult, WebSocketBodyOutput} +import sttp.ws.WebSocketFrame + +import scala.concurrent.ExecutionContext.Implicits +import scala.concurrent.Promise +import scala.util.{Failure, Success} + +class WebSocketPipeProcessor[F[_]: Async, REQ, RESP]( + pipe: Pipe[F, REQ, RESP], + dispatcher: Dispatcher[F], + o: WebSocketBodyOutput[Pipe[F, REQ, RESP], REQ, RESP, ?, Fs2Streams[F]] +) extends Processor[NettyWebSocketFrame, NettyWebSocketFrame] { + private var subscriber: StreamSubscriber[F, NettyWebSocketFrame] = _ + private val publisher: Promise[Publisher[NettyWebSocketFrame]] = Promise[Publisher[NettyWebSocketFrame]]() + private var subscription: Subscription = _ + + override def onSubscribe(s: Subscription): Unit = { + subscriber = dispatcher.unsafeRunSync( + StreamSubscriber[F, NettyWebSocketFrame](bufferSize = 1) + ) + subscription = s + val in: Stream[F, NettyWebSocketFrame] = subscriber.sub.stream(Applicative[F].unit) + val sttpFrames = in.map { f => + val sttpFrame = nettyFrameToFrame(f) + f.release() + sttpFrame + } + val stream: Stream[F, NettyWebSocketFrame] = + optionallyConcatenateFrames(sttpFrames, o.concatenateFragmentedFrames) + .map(f => + o.requests.decode(f) match { + case x: DecodeResult.Value[REQ] => x.v + case failure: DecodeResult.Failure => throw new WebSocketFrameDecodeFailure(f, failure) + } + ) + .through(pipe) + .map(r => frameToNettyFrame(o.responses.encode(r))) + .append(fs2.Stream(frameToNettyFrame(WebSocketFrame.close))) + + subscriber.sub.onSubscribe(s) + publisher.success(StreamUnicastPublisher(stream, dispatcher)) + } + + override def onNext(t: NettyWebSocketFrame): Unit = { + subscriber.sub.onNext(t) + } + + override def onError(t: Throwable): Unit = { + subscriber.sub.onError(t) + } + + override def onComplete(): Unit = { + subscriber.sub.onComplete() + } + + override def subscribe(s: Subscriber[_ >: NettyWebSocketFrame]): Unit = { + publisher.future.onComplete { + case Success(p) => + p.subscribe(s) + case Failure(ex) => + subscriber.sub.onError(ex) + subscription.cancel + }(Implicits.global) + } + + private def optionallyConcatenateFrames(s: Stream[F, WebSocketFrame], doConcatenate: Boolean): Stream[F, WebSocketFrame] = + if (doConcatenate) { + type Accumulator = Option[Either[Array[Byte], String]] + + s.mapAccumulate(None: Accumulator) { + case (None, f: WebSocketFrame.Ping) => (None, Some(f)) + case (None, f: WebSocketFrame.Pong) => (None, Some(f)) + case (None, f: WebSocketFrame.Close) => (None, Some(f)) + case (None, f: WebSocketFrame.Data[_]) if f.finalFragment => (None, Some(f)) + case (Some(Left(acc)), f: WebSocketFrame.Binary) if f.finalFragment => (None, 
Some(f.copy(payload = acc ++ f.payload))) + case (Some(Left(acc)), f: WebSocketFrame.Binary) if !f.finalFragment => (Some(Left(acc ++ f.payload)), None) + case (Some(Right(acc)), f: WebSocketFrame.Text) if f.finalFragment => (None, Some(f.copy(payload = acc + f.payload))) + case (Some(Right(acc)), f: WebSocketFrame.Text) if !f.finalFragment => (Some(Right(acc + f.payload)), None) + case (acc, f) => throw new IllegalStateException(s"Cannot accumulate web socket frames. Accumulator: $acc, frame: $f.") + }.collect { case (_, Some(f)) => f } + } else s +} diff --git a/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsServerTest.scala b/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsServerTest.scala index e7a1280719..e3252a25a1 100644 --- a/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsServerTest.scala +++ b/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsServerTest.scala @@ -40,7 +40,11 @@ class NettyCatsServerTest extends TestSuite with EitherValues { new ServerStreamingTests(createServerTest).tests(Fs2Streams[IO])(drainFs2) ++ new ServerCancellationTests(createServerTest)(m, IO.asyncForIO).tests() ++ new NettyFs2StreamingCancellationTest(createServerTest).tests() ++ - new ServerGracefulShutdownTests(createServerTest, ioSleeper).tests() + new ServerGracefulShutdownTests(createServerTest, ioSleeper).tests() ++ + new ServerWebSocketTests(createServerTest, Fs2Streams[IO]) { + override def functionToPipe[A, B](f: A => B): streams.Pipe[A, B] = in => in.map(f) + override def emptyPipe[A, B]: fs2.Pipe[IO, A, B] = _ => fs2.Stream.empty + }.tests() IO.pure((tests, eventLoopGroup)) } { case (_, eventLoopGroup) => diff --git a/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsTestServerInterpreter.scala b/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsTestServerInterpreter.scala index 61be7e6f4d..9334504a2f 100644 --- a/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsTestServerInterpreter.scala +++ b/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsTestServerInterpreter.scala @@ -4,6 +4,7 @@ import cats.data.NonEmptyList import cats.effect.std.Dispatcher import cats.effect.{IO, Resource} import io.netty.channel.nio.NioEventLoopGroup +import sttp.capabilities.WebSockets import sttp.tapir.server.ServerEndpoint import sttp.tapir.server.netty.{NettyConfig, Route} import sttp.tapir.server.tests.TestServerInterpreter @@ -12,8 +13,8 @@ import sttp.capabilities.fs2.Fs2Streams import scala.concurrent.duration.FiniteDuration class NettyCatsTestServerInterpreter(eventLoopGroup: NioEventLoopGroup, dispatcher: Dispatcher[IO]) - extends TestServerInterpreter[IO, Fs2Streams[IO], NettyCatsServerOptions[IO], Route[IO]] { - override def route(es: List[ServerEndpoint[Fs2Streams[IO], IO]], interceptors: Interceptors): Route[IO] = { + extends TestServerInterpreter[IO, Fs2Streams[IO] with WebSockets, NettyCatsServerOptions[IO], Route[IO]] { + override def route(es: List[ServerEndpoint[Fs2Streams[IO] with WebSockets, IO]], interceptors: Interceptors): Route[IO] = { val serverOptions: NettyCatsServerOptions[IO] = interceptors( NettyCatsServerOptions.customiseInterceptors[IO](dispatcher) ).options diff --git a/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/NettyIdServer.scala b/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/NettyIdServer.scala index 
46918fc8fd..2de073f763 100644 --- a/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/NettyIdServer.scala +++ b/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/NettyIdServer.scala @@ -25,6 +25,7 @@ import scala.concurrent.Future import scala.concurrent.Promise import scala.concurrent.duration.FiniteDuration import scala.util.control.NonFatal +import sttp.tapir.server.netty.internal.ReactiveWebSocketHandler case class NettyIdServer(routes: Vector[IdRoute], options: NettyIdServerOptions, config: NettyConfig) { private val executor = Executors.newVirtualThreadPerTaskExecutor() @@ -100,6 +101,7 @@ case class NettyIdServer(routes: Vector[IdRoute], options: NettyIdServerOptions, isShuttingDown, config.serverHeader ), + new ReactiveWebSocketHandler(route, channelGroup, unsafeRunF, config.sslContext.isDefined), eventLoopGroup, socketOverride ) diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyConfig.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyConfig.scala index 07e43954c5..1d6d66144d 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyConfig.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyConfig.scala @@ -61,7 +61,7 @@ case class NettyConfig( sslContext: Option[SslContext], eventLoopConfig: EventLoopConfig, socketConfig: NettySocketConfig, - initPipeline: NettyConfig => (ChannelPipeline, ChannelHandler) => Unit, + initPipeline: NettyConfig => (ChannelPipeline, ChannelHandler, ChannelHandler) => Unit, gracefulShutdownTimeout: Option[FiniteDuration], serverHeader: Option[String] ) { @@ -96,7 +96,7 @@ case class NettyConfig( def eventLoopConfig(elc: EventLoopConfig): NettyConfig = copy(eventLoopConfig = elc) def eventLoopGroup(elg: EventLoopGroup): NettyConfig = copy(eventLoopConfig = EventLoopConfig.useExisting(elg)) - def initPipeline(f: NettyConfig => (ChannelPipeline, ChannelHandler) => Unit): NettyConfig = copy(initPipeline = f) + def initPipeline(f: NettyConfig => (ChannelPipeline, ChannelHandler, ChannelHandler) => Unit): NettyConfig = copy(initPipeline = f) def withGracefulShutdownTimeout(t: FiniteDuration) = copy(gracefulShutdownTimeout = Some(t)) def noGracefulShutdown = copy(gracefulShutdownTimeout = None) @@ -120,14 +120,15 @@ object NettyConfig { sslContext = None, eventLoopConfig = EventLoopConfig.auto, socketConfig = NettySocketConfig.default, - initPipeline = cfg => defaultInitPipeline(cfg)(_, _), + initPipeline = cfg => defaultInitPipeline(cfg)(_, _, _), serverHeader = Some(s"tapir/${buildinfo.BuildInfo.version}") ) - def defaultInitPipeline(cfg: NettyConfig)(pipeline: ChannelPipeline, handler: ChannelHandler): Unit = { + def defaultInitPipeline(cfg: NettyConfig)(pipeline: ChannelPipeline, handler: ChannelHandler, wsHandler: ChannelHandler): Unit = { cfg.sslContext.foreach(s => pipeline.addLast(s.newHandler(pipeline.channel().alloc()))) - pipeline.addLast(new HttpServerCodec()) - pipeline.addLast(new HttpStreamsServerHandler()) + pipeline.addLast("serverCodecHandler", new HttpServerCodec()) + pipeline.addLast("streamsHandler", new HttpStreamsServerHandler()) + pipeline.addLast("wsHandler", wsHandler) pipeline.addLast(handler) if (cfg.addLoggingHandler) pipeline.addLast(new LoggingHandler()) () diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyFutureServer.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyFutureServer.scala index f510e8e85b..2c972c3226 100644 --- 
a/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyFutureServer.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyFutureServer.scala @@ -16,6 +16,7 @@ import java.util.UUID import java.util.concurrent.atomic.AtomicBoolean import scala.concurrent.duration.FiniteDuration import scala.concurrent.{ExecutionContext, Future, blocking} +import sttp.tapir.server.netty.internal.ReactiveWebSocketHandler case class NettyFutureServer(routes: Vector[FutureRoute], options: NettyFutureServerOptions, config: NettyConfig)(implicit ec: ExecutionContext @@ -72,6 +73,7 @@ case class NettyFutureServer(routes: Vector[FutureRoute], options: NettyFutureSe NettyBootstrap( config, new NettyServerHandler(route, unsafeRunAsync, channelGroup, isShuttingDown, config.serverHeader), + new ReactiveWebSocketHandler(route, channelGroup, unsafeRunAsync, config.sslContext.isDefined), eventLoopGroup, socketOverride ) diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyResponseContent.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyResponseContent.scala index b17335f60b..d3aea1dda6 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyResponseContent.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyResponseContent.scala @@ -5,6 +5,10 @@ import io.netty.channel.ChannelPromise import io.netty.handler.codec.http.HttpContent import io.netty.handler.stream.{ChunkedFile, ChunkedStream} import org.reactivestreams.Publisher +import org.reactivestreams.Processor +import io.netty.handler.codec.http.websocketx.WebSocketFrame +import sttp.ws.{WebSocketFrame => SttpWebSocketFrame} +import scala.concurrent.duration.FiniteDuration sealed trait NettyResponseContent { def channelPromise: ChannelPromise @@ -17,4 +21,12 @@ object NettyResponseContent { final case class ChunkedFileNettyResponseContent(channelPromise: ChannelPromise, chunkedFile: ChunkedFile) extends NettyResponseContent final case class ReactivePublisherNettyResponseContent(channelPromise: ChannelPromise, publisher: Publisher[HttpContent]) extends NettyResponseContent + final case class ReactiveWebSocketProcessorNettyResponseContent( + channelPromise: ChannelPromise, + processor: Processor[WebSocketFrame, WebSocketFrame], + ignorePong: Boolean, + autoPongOnPing: Boolean, + decodeCloseRequests: Boolean, + autoPing: Option[(FiniteDuration, SttpWebSocketFrame.Ping)] + ) extends NettyResponseContent } diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyBootstrap.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyBootstrap.scala index 71d57519a8..68835c8f07 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyBootstrap.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyBootstrap.scala @@ -12,6 +12,7 @@ object NettyBootstrap { def apply[F[_]]( nettyConfig: NettyConfig, handler: => NettyServerHandler[F], + wsHandler: => ReactiveWebSocketHandler[F], eventLoopGroup: EventLoopGroup, overrideSocketAddress: Option[SocketAddress] ): ChannelFuture = { @@ -27,8 +28,12 @@ object NettyBootstrap { nettyConfig.requestTimeout match { case Some(requestTimeout) => - nettyConfigBuilder(ch.pipeline().addLast(new ReadTimeoutHandler(requestTimeout.toSeconds.toInt)), handler) - case None => nettyConfigBuilder(ch.pipeline(), handler) + nettyConfigBuilder( + ch.pipeline().addLast("readTimeoutHandler", new 
ReadTimeoutHandler(requestTimeout.toSeconds.toInt)), + handler, + wsHandler + ) + case None => nettyConfigBuilder(ch.pipeline(), handler, wsHandler) } connectionCounterOpt.map(counter => { diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala index c3333a7ff4..3da92d8f56 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala @@ -4,6 +4,7 @@ import io.netty.buffer.{ByteBuf, Unpooled} import io.netty.channel._ import io.netty.channel.group.ChannelGroup import io.netty.handler.codec.http._ +import io.netty.handler.codec.http.websocketx.WebSocketServerHandshakerFactory import io.netty.handler.stream.{ChunkedFile, ChunkedStream} import org.playframework.netty.http.{DefaultStreamedHttpResponse, StreamedHttpRequest} import org.reactivestreams.Publisher @@ -25,6 +26,8 @@ import scala.collection.mutable.{Queue => MutableQueue} import scala.concurrent.{ExecutionContext, Future} import scala.util.control.NonFatal import scala.util.{Failure, Success} +import sttp.tapir.EndpointInput.AuthType.Http +import sttp.tapir.server.netty.NettyResponseContent.ReactiveWebSocketProcessorNettyResponseContent /** @param unsafeRunAsync * Function which dispatches given effect to run asynchronously, returning its result as a Future, and function of type `() => @@ -211,6 +214,12 @@ class NettyServerHandler[F[_]]( }) ctx.writeAndFlush(res, channelPromise).closeIfNeeded(req) + }, + wsHandler = (channelPromise) => { + logger.error("Unexpected WebSocket processor response received in NettyServerHandler, it should be handled only in the ReactiveWebSocketHandler") + val res = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.INTERNAL_SERVER_ERROR) + res.headers.set(HttpHeaderNames.CONNECTION, HttpHeaderValues.CLOSE) + val _ = ctx.writeAndFlush(res) }, noBodyHandler = () => { val res = new DefaultFullHttpResponse( @@ -234,6 +243,7 @@ class NettyServerHandler[F[_]]( chunkedStreamHandler: (ChannelPromise, ChunkedStream) => Unit, chunkedFileHandler: (ChannelPromise, ChunkedFile) => Unit, reactiveStreamHandler: (ChannelPromise, Publisher[HttpContent]) => Unit, + wsHandler: ChannelPromise => Unit, noBodyHandler: () => Unit ): Unit = { r.body match { @@ -245,6 +255,7 @@ class NettyServerHandler[F[_]]( case r: ChunkedStreamNettyResponseContent => chunkedStreamHandler(r.channelPromise, r.chunkedStream) case r: ChunkedFileNettyResponseContent => chunkedFileHandler(r.channelPromise, r.chunkedFile) case r: ReactivePublisherNettyResponseContent => reactiveStreamHandler(r.channelPromise, r.publisher) + case r: ReactiveWebSocketProcessorNettyResponseContent => wsHandler(r.channelPromise) } } case None => noBodyHandler() diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyToStreamsResponseBody.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyToStreamsResponseBody.scala index 0f335a7b14..ae94a180b3 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyToStreamsResponseBody.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyToStreamsResponseBody.scala @@ -12,12 +12,14 @@ import sttp.tapir.{CodecFormat, RawBodyType, WebSocketBodyOutput} import java.nio.ByteBuffer import java.nio.charset.Charset +import 
sttp.tapir.server.netty.NettyResponseContent.ReactiveWebSocketProcessorNettyResponseContent -/** Common logic for producing response body in all Netty backends that support streaming. These backends use streaming libraries - * like fs2 or zio-streams to obtain reactive Publishers representing responses like InputStreamBody, InputStreamRangeBody or FileBody. - * Other kinds of raw responses like directly available String, ByteArray or ByteBuffer can be returned without wrapping into a Publisher. +/** Common logic for producing response body in all Netty backends that support streaming. These backends use streaming libraries like fs2 + * or zio-streams to obtain reactive Publishers representing responses like InputStreamBody, InputStreamRangeBody or FileBody. Other kinds + * of raw responses like directly available String, ByteArray or ByteBuffer can be returned without wrapping into a Publisher. */ -private[netty] class NettyToStreamsResponseBody[S <: Streams[S]](streamCompatible: StreamCompatible[S]) extends ToResponseBody[NettyResponse, S] { +private[netty] class NettyToStreamsResponseBody[S <: Streams[S]](streamCompatible: StreamCompatible[S]) + extends ToResponseBody[NettyResponse, S] { override val streams: S = streamCompatible.streams @@ -37,7 +39,10 @@ private[netty] class NettyToStreamsResponseBody[S <: Streams[S]](streamCompatibl case RawBodyType.InputStreamBody => (ctx: ChannelHandlerContext) => - new ReactivePublisherNettyResponseContent(ctx.newPromise(), streamCompatible.publisherFromInputStream(() => v, DefaultChunkSize, length = None)) + new ReactivePublisherNettyResponseContent( + ctx.newPromise(), + streamCompatible.publisherFromInputStream(() => v, DefaultChunkSize, length = None) + ) case RawBodyType.InputStreamRangeBody => (ctx: ChannelHandlerContext) => @@ -47,7 +52,8 @@ private[netty] class NettyToStreamsResponseBody[S <: Streams[S]](streamCompatibl ) case RawBodyType.FileBody => - (ctx: ChannelHandlerContext) => new ReactivePublisherNettyResponseContent(ctx.newPromise(), streamCompatible.publisherFromFile(v, DefaultChunkSize)) + (ctx: ChannelHandlerContext) => + new ReactivePublisherNettyResponseContent(ctx.newPromise(), streamCompatible.publisherFromFile(v, DefaultChunkSize)) case _: RawBodyType.MultipartBody => throw new UnsupportedOperationException } @@ -68,5 +74,17 @@ private[netty] class NettyToStreamsResponseBody[S <: Streams[S]](streamCompatibl override def fromWebSocketPipe[REQ, RESP]( pipe: streams.Pipe[REQ, RESP], o: WebSocketBodyOutput[streams.Pipe[REQ, RESP], REQ, RESP, _, S] - ): NettyResponse = throw new UnsupportedOperationException + ): NettyResponse = (ctx: ChannelHandlerContext) => { + new ReactiveWebSocketProcessorNettyResponseContent( + ctx.newPromise(), + streamCompatible.asWsProcessor( + pipe.asInstanceOf[streamCompatible.streams.Pipe[REQ, RESP]], + o.asInstanceOf[WebSocketBodyOutput[streamCompatible.streams.Pipe[REQ, RESP], REQ, RESP, _, S]] + ), + ignorePong = o.ignorePong, + autoPongOnPing = o.autoPongOnPing, + decodeCloseRequests = o.decodeCloseRequests, + autoPing = o.autoPing + ) + } } diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/ReactiveWebSocketHandler.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/ReactiveWebSocketHandler.scala new file mode 100644 index 0000000000..7fc695ab78 --- /dev/null +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/ReactiveWebSocketHandler.scala @@ -0,0 +1,161 @@ +package sttp.tapir.server.netty.internal + +import 
io.netty.channel.group.ChannelGroup +import io.netty.channel.{ChannelHandlerContext, ChannelInboundHandlerAdapter} +import io.netty.handler.codec.http.websocketx.WebSocketServerHandshakerFactory +import io.netty.handler.codec.http._ +import io.netty.util.ReferenceCountUtil +import org.playframework.netty.http.DefaultWebSocketHttpResponse +import org.slf4j.LoggerFactory +import sttp.monad.MonadError +import sttp.monad.syntax._ +import sttp.tapir.server.model.ServerResponse +import sttp.tapir.server.netty.NettyResponseContent.ReactiveWebSocketProcessorNettyResponseContent +import sttp.tapir.server.netty.{NettyResponse, NettyServerRequest, Route} + +import scala.concurrent.{ExecutionContext, Future} +import scala.util.control.NonFatal +import scala.util.{Failure, Success} + +/** Handles a WS handshake and initiates the communication by calling Tapir interpreter to get a Pipe, then sends that Pipe to the rest of + * the processing pipeline and removes itself from the pipeline. + */ +class ReactiveWebSocketHandler[F[_]]( + route: Route[F], + channelGroup: ChannelGroup, + unsafeRunAsync: (() => F[ServerResponse[NettyResponse]]) => (Future[ServerResponse[NettyResponse]], () => Future[Unit]), + isSsl: Boolean +)(implicit m: MonadError[F]) + extends ChannelInboundHandlerAdapter { + + // By using the Netty event loop assigned to this channel we get two benefits: + // 1. We can avoid the necessary hopping around of threads since Netty pipelines will + // only pass events up and down from within the event loop to which it is assigned. + // That means calls to ctx.read(), and ctx.write(..), would have to be trampolined otherwise. + // 2. We get serialization of execution: the EventLoop is a serial execution queue so + // we can rest easy knowing that no two events will be executed in parallel. 
+ private[this] var eventLoopContext: ExecutionContext = _ + + private val logger = LoggerFactory.getLogger(getClass.getName) + + def isWsHandshake(req: HttpRequest): Boolean = + "Websocket".equalsIgnoreCase(req.headers().get(HttpHeaderNames.UPGRADE)) && + "Upgrade".equalsIgnoreCase(req.headers().get(HttpHeaderNames.CONNECTION)) + + override def handlerAdded(ctx: ChannelHandlerContext): Unit = + if (ctx.channel.isActive) { + initHandler(ctx) + } + override def channelActive(ctx: ChannelHandlerContext): Unit = { + channelGroup.add(ctx.channel) + initHandler(ctx) + } + + private[this] def initHandler(ctx: ChannelHandlerContext): Unit = { + if (eventLoopContext == null) + eventLoopContext = ExecutionContext.fromExecutor(ctx.channel.eventLoop) + } + + override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = { + logger.error("Error while processing the request", cause) + } + + override def channelRead(ctx: ChannelHandlerContext, msg: Any): Unit = { + def writeError500(req: HttpRequest, reason: Throwable): Unit = { + logger.error("Error while processing the request", reason) + val res = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.INTERNAL_SERVER_ERROR) + res.headers().set(HttpHeaderNames.CONTENT_LENGTH, 0) + val _ = ctx.writeAndFlush(res) + } + msg match { + case req: FullHttpRequest if isWsHandshake(req) => + ctx.pipeline().remove(this) + ctx.pipeline().remove("readTimeoutHandler") + ReferenceCountUtil.release(msg) + val (runningFuture, _) = unsafeRunAsync { () => + route(NettyServerRequest(req.retain())) + .map { + case Some(response) => response + case None => ServerResponse.notFound + } + } + + val _ = runningFuture.transform { + case Success(serverResponse) => + try { + serverResponse.body match { + case Some(function) => { + val content = function(ctx) + content match { + case r: ReactiveWebSocketProcessorNettyResponseContent => { + ctx + .pipeline() + .addAfter( + "serverCodecHandler", + "wsControlFrameHandler", + new NettyControlFrameHandler( + ignorePong = r.ignorePong, + autoPongOnPing = r.autoPongOnPing, + decodeCloseRequests = r.decodeCloseRequests + ) + ) + r.autoPing.foreach { case (interval, pingMsg) => + ctx.pipeline().addFirst("wsAutoPingHandler", new WebSocketAutoPingHandler(interval, pingMsg)) + } + r.channelPromise.setSuccess() + val _ = ctx.writeAndFlush( + // Push a special message down the pipeline, it will be handled by HttpStreamsServerHandler + // and from now on that handler will take control of the flow (our NettyServerHandler will not receive messages) + new DefaultWebSocketHttpResponse( + req.protocolVersion(), + HttpResponseStatus.valueOf(200), + r.processor, // the Processor (Pipe) created by Tapir interpreter will be used by HttpStreamsServerHandler + new WebSocketServerHandshakerFactory(wsUrl(req), null, false) + ) + ) + } + case otherContent => + logger.error(s"Unexpected response content: $otherContent") + val res = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.INTERNAL_SERVER_ERROR) + res.headers.set(HttpHeaderNames.CONNECTION, HttpHeaderValues.CLOSE) + otherContent.channelPromise.setFailure(new IllegalStateException("Unexpected response content")) + val _ = ctx.writeAndFlush(res) + } + } + case None => + logger.error("Missing response body, expected WebSocketProcessorNettyResponseContent") + val res = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.INTERNAL_SERVER_ERROR) + res.headers.set(HttpHeaderNames.CONNECTION, HttpHeaderValues.CLOSE) + val _ = 
ctx.writeAndFlush(res) + } + Success(()) + } catch { + case NonFatal(ex) => + writeError500(req, ex) + Failure(ex) + } finally { + val _ = req.release() + } + case Failure(NonFatal(ex)) => + try { + writeError500(req, ex) + Failure(ex) + } finally { + val _ = req.release() + } + case Failure(fatalException) => Failure(fatalException) + }(eventLoopContext) + + case other => + // not a WS handshake, from now on process messages as normal HTTP requests in this channel + ctx.pipeline.remove(this) + val _ = ctx.fireChannelRead(other) + } + } + + // Only ancient WS protocol versions will use this in the response header. + private def wsUrl(req: FullHttpRequest): String = { + val scheme = if (isSsl) "wss" else "ws" + s"$scheme://${req.headers().get(HttpHeaderNames.HOST)}${req.uri()}" + } +} diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/StreamCompatible.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/StreamCompatible.scala index 6d5da177bd..b973aef7eb 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/StreamCompatible.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/StreamCompatible.scala @@ -1,9 +1,10 @@ package sttp.tapir.server.netty.internal import io.netty.handler.codec.http.HttpContent -import org.reactivestreams.Publisher +import io.netty.handler.codec.http.websocketx.WebSocketFrame +import org.reactivestreams.{Processor, Publisher} import sttp.capabilities.Streams -import sttp.tapir.FileRange +import sttp.tapir.{FileRange, WebSocketBodyOutput} import java.io.InputStream @@ -27,4 +28,6 @@ private[netty] trait StreamCompatible[S <: Streams[S]] { def publisherFromInputStream(is: () => InputStream, chunkSize: Int, length: Option[Long]): Publisher[HttpContent] = asPublisher(fromInputStream(is, chunkSize, length)) + + def asWsProcessor[REQ, RESP](pipe: streams.Pipe[REQ, RESP], o: WebSocketBodyOutput[streams.Pipe[REQ, RESP], REQ, RESP, _, S]): Processor[WebSocketFrame, WebSocketFrame] } diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/WebSocketAutoPingHandler.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/WebSocketAutoPingHandler.scala new file mode 100644 index 0000000000..62addee363 --- /dev/null +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/WebSocketAutoPingHandler.scala @@ -0,0 +1,45 @@ +package sttp.tapir.server.netty.internal + +import io.netty.buffer.Unpooled +import io.netty.channel.{ChannelHandlerContext, ChannelInboundHandlerAdapter} +import io.netty.handler.codec.http.websocketx.PingWebSocketFrame +import org.slf4j.LoggerFactory + +import java.util.concurrent.{ScheduledFuture, TimeUnit} +import scala.concurrent.duration.FiniteDuration + +/** If auto ping is enabled for an endpoint, this handler will be plugged into the pipeline. Its responsibility is to manage start and stop + * of the ping scheduler. + * TODO: should we include logic for closing the channel and reporting an error if some kind of ping timeout is exceeded? + * @param pingInterval + * time interval to be used between sending pings to the client. + * @param frame + * desired content of the Ping frame, as specified in the Tapir endpoint output. 
+ */ +class WebSocketAutoPingHandler(pingInterval: FiniteDuration, frame: sttp.ws.WebSocketFrame.Ping) extends ChannelInboundHandlerAdapter { + val nettyFrame = new PingWebSocketFrame(Unpooled.copiedBuffer(frame.payload)) + private var pingTask: ScheduledFuture[_] = _ + + private val logger = LoggerFactory.getLogger(getClass.getName) + + override def handlerAdded(ctx: ChannelHandlerContext): Unit = { + if (ctx.channel.isActive) { + logger.debug(s"STARTING WebSocket Ping scheduler for channel ${ctx.channel}, interval = $pingInterval") + val sendPing: Runnable = new Runnable { + override def run(): Unit = { + logger.trace(s"Sending PING WebSocket frame for channel ${ctx.channel}") + val _ = ctx.writeAndFlush(nettyFrame.retain()) + } + } + pingTask = ctx.channel().eventLoop().scheduleAtFixedRate(sendPing, pingInterval.toMillis, pingInterval.toMillis, TimeUnit.MILLISECONDS) + } + } + + override def channelInactive(ctx: ChannelHandlerContext): Unit = { + super.channelInactive(ctx) + logger.debug(s"STOPPING WebSocket Ping scheduler for channel ${ctx.channel}") + if (pingTask != null) { + val _ = pingTask.cancel(false) + } + } +} diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/WebSocketControlFrameHandler.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/WebSocketControlFrameHandler.scala new file mode 100644 index 0000000000..098e76bcfe --- /dev/null +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/WebSocketControlFrameHandler.scala @@ -0,0 +1,33 @@ +package sttp.tapir.server.netty.internal + +import io.netty.channel.ChannelInboundHandlerAdapter +import io.netty.channel.ChannelHandlerContext +import io.netty.handler.codec.http.websocketx.PingWebSocketFrame +import io.netty.handler.codec.http.websocketx.PongWebSocketFrame +import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame + +/** + * Handles Ping, Pong, and Close frames for WebSockets. 
+ */ +class NettyControlFrameHandler(ignorePong: Boolean, autoPongOnPing: Boolean, decodeCloseRequests: Boolean) + extends ChannelInboundHandlerAdapter { + + override def channelRead(ctx: ChannelHandlerContext, msg: Any): Unit = { + msg match { + case ping: PingWebSocketFrame if autoPongOnPing => + val _ = ctx.writeAndFlush(new PongWebSocketFrame(ping.content().retain())) + case pong: PongWebSocketFrame if !ignorePong => + val _ = ctx.fireChannelRead(pong) + case close: CloseWebSocketFrame => + if (decodeCloseRequests) { + // Passing the Close frame for further processing + val _ = ctx.fireChannelRead(close) + } else { + // Responding with Close immediately + val _ = ctx.writeAndFlush(close) + } + case other => + val _ = ctx.fireChannelRead(other) + } + } +} diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/WebSocketFrameConverters.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/WebSocketFrameConverters.scala new file mode 100644 index 0000000000..2d5bc06615 --- /dev/null +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/WebSocketFrameConverters.scala @@ -0,0 +1,32 @@ +package sttp.tapir.server.netty.internal + +import io.netty.handler.codec.http.websocketx._ +import io.netty.handler.codec.http.websocketx.{WebSocketFrame => NettyWebSocketFrame} +import sttp.ws.WebSocketFrame +import io.netty.buffer.Unpooled + +object WebSocketFrameConverters { + + def nettyFrameToFrame(nettyFrame: NettyWebSocketFrame): WebSocketFrame = { + nettyFrame match { + case text: TextWebSocketFrame => WebSocketFrame.Text(text.text, text.isFinalFragment, Some(text.rsv)) + case close: CloseWebSocketFrame => WebSocketFrame.Close(close.statusCode, close.reasonText) + case ping: PingWebSocketFrame => WebSocketFrame.Ping(ping.content().nioBuffer().array()) + case pong: PongWebSocketFrame => WebSocketFrame.Pong(pong.content().nioBuffer().array()) + case _ => WebSocketFrame.Binary(nettyFrame.content().nioBuffer().array(), nettyFrame.isFinalFragment, Some(nettyFrame.rsv)) + } + } + + def frameToNettyFrame(w: WebSocketFrame): NettyWebSocketFrame = w match { + case WebSocketFrame.Text(payload, finalFragment, rsvOpt) => + new TextWebSocketFrame(finalFragment, rsvOpt.getOrElse(0), payload) + case WebSocketFrame.Close(statusCode, reasonText) => + new CloseWebSocketFrame(statusCode, reasonText) + case WebSocketFrame.Ping(payload) => + new PingWebSocketFrame(Unpooled.wrappedBuffer(payload)) + case WebSocketFrame.Pong(payload) => + new PongWebSocketFrame(Unpooled.wrappedBuffer(payload)) + case WebSocketFrame.Binary(payload, finalFragment, rsvOpt) => + new BinaryWebSocketFrame(finalFragment, rsvOpt.getOrElse(0), Unpooled.wrappedBuffer(payload)) + } +} diff --git a/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/NettyZioServer.scala b/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/NettyZioServer.scala index 8eb00c55bc..da8c9082d3 100644 --- a/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/NettyZioServer.scala +++ b/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/NettyZioServer.scala @@ -7,7 +7,7 @@ import io.netty.util.concurrent.DefaultEventExecutor import sttp.capabilities.zio.ZioStreams import sttp.tapir.server.ServerEndpoint import sttp.tapir.server.model.ServerResponse -import sttp.tapir.server.netty.internal.{NettyBootstrap, NettyServerHandler} +import sttp.tapir.server.netty.internal.{NettyBootstrap, NettyServerHandler, ReactiveWebSocketHandler} import 
sttp.tapir.server.netty.zio.internal.ZioUtil.{nettyChannelFutureToScala, nettyFutureToScala} import sttp.tapir.server.netty.{NettyConfig, NettyResponse, Route} import sttp.tapir.ztapir.{RIOMonadError, ZServerEndpoint} @@ -93,6 +93,7 @@ case class NettyZioServer[R](routes: Vector[RIO[R, Route[RIO[R, *]]]], options: isShuttingDown, config.serverHeader ), + new ReactiveWebSocketHandler[RIO[R, *]](route, channelGroup, unsafeRunAsync(runtime), config.sslContext.isDefined), eventLoopGroup, socketOverride ) diff --git a/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/internal/ZioStreamCompatible.scala b/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/internal/ZioStreamCompatible.scala index 60a5a2b067..14ad954784 100644 --- a/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/internal/ZioStreamCompatible.scala +++ b/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/internal/ZioStreamCompatible.scala @@ -4,11 +4,12 @@ import _root_.zio._ import _root_.zio.interop.reactivestreams._ import _root_.zio.stream.{Stream, ZStream} import io.netty.buffer.Unpooled +import io.netty.handler.codec.http.websocketx.WebSocketFrame import io.netty.handler.codec.http.{DefaultHttpContent, HttpContent} -import org.reactivestreams.Publisher +import org.reactivestreams.{Processor, Publisher} import sttp.capabilities.zio.ZioStreams -import sttp.tapir.FileRange import sttp.tapir.server.netty.internal._ +import sttp.tapir.{FileRange, WebSocketBodyOutput} import java.io.InputStream @@ -64,6 +65,12 @@ private[zio] object ZioStreamCompatible { override def emptyStream: streams.BinaryStream = ZStream.empty + + override def asWsProcessor[REQ, RESP]( + pipe: Stream[Throwable, REQ] => Stream[Throwable, RESP], + o: WebSocketBodyOutput[Stream[Throwable, REQ] => Stream[Throwable, RESP], REQ, RESP, ?, ZioStreams] + ): Processor[WebSocketFrame, WebSocketFrame] = + throw new UnsupportedOperationException("TODO") } } } diff --git a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerWebSocketTests.scala b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerWebSocketTests.scala index 378fd86ddf..05cfaad8bb 100644 --- a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerWebSocketTests.scala +++ b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerWebSocketTests.scala @@ -43,11 +43,13 @@ abstract class ServerWebSocketTests[F[_], S <: Streams[S], OPTIONS, ROUTE]( _ <- ws.sendText("test2") m1 <- ws.receiveText() m2 <- ws.receiveText() - } yield List(m1, m2) + _ <- ws.close() + m3 <- ws.eitherClose(ws.receiveText()) + } yield List(m1, m2, m3) }) .get(baseUri.scheme("ws")) .send(backend) - .map(_.body shouldBe Right(List("echo: test1", "echo: test2"))) + .map(_.body shouldBe Right(List("echo: test1", "echo: test2", Left(WebSocketFrame.Close(1000, "normal closure"))))) }, { val reqCounter = newRequestCounter[F] From 5f616714657e6495c0d481162a7aa6d06104c69c Mon Sep 17 00:00:00 2001 From: kciesielski Date: Mon, 25 Mar 2024 08:28:30 +0100 Subject: [PATCH 02/31] Add comments --- .../netty/cats/internal/WebSocketPipeProcessor.scala | 7 +++++++ .../server/netty/internal/ReactiveWebSocketHandler.scala | 3 ++- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/WebSocketPipeProcessor.scala b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/WebSocketPipeProcessor.scala index 518db27e65..80f2960d4d 100644 --- 
a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/WebSocketPipeProcessor.scala
+++ b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/WebSocketPipeProcessor.scala
@@ -17,6 +17,8 @@ import scala.concurrent.ExecutionContext.Implicits
 import scala.concurrent.Promise
 import scala.util.{Failure, Success}
 
+/** A Reactive Streams Processor[NettyWebSocketFrame, NettyWebSocketFrame] built from a fs2.Pipe[F, REQ, RESP] passed from a WS endpoint.
+  */
 class WebSocketPipeProcessor[F[_]: Async, REQ, RESP](
     pipe: Pipe[F, REQ, RESP],
     dispatcher: Dispatcher[F],
@@ -27,7 +29,9 @@ class WebSocketPipeProcessor[F[_]: Async, REQ, RESP](
   private var subscription: Subscription = _
 
   override def onSubscribe(s: Subscription): Unit = {
+    // Not really that unsafe. Subscriber creation doesn't do any IO, only initializes an AtomicReference in an initial state.
     subscriber = dispatcher.unsafeRunSync(
+      // If bufferSize > 1, the stream may stall and not emit responses until enough requests are buffered
       StreamSubscriber[F, NettyWebSocketFrame](bufferSize = 1)
     )
     subscription = s
@@ -49,7 +53,9 @@ class WebSocketPipeProcessor[F[_]: Async, REQ, RESP](
       )
       .through(pipe)
       .map(r => frameToNettyFrame(o.responses.encode(r)))
       .append(fs2.Stream(frameToNettyFrame(WebSocketFrame.close)))
 
+    // Trigger listening for WS frames in the underlying fs2 StreamSubscriber
     subscriber.sub.onSubscribe(s)
+    // Signal that a Publisher is ready to send result frames
     publisher.success(StreamUnicastPublisher(stream, dispatcher))
   }
 
@@ -66,6 +72,7 @@ class WebSocketPipeProcessor[F[_]: Async, REQ, RESP](
   }
 
   override def subscribe(s: Subscriber[_ >: NettyWebSocketFrame]): Unit = {
+    // A subscriber may come to read from our internal Publisher. It has to wait for the Publisher to be initialized.
     publisher.future.onComplete {
       case Success(p) =>
         p.subscribe(s)
diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/ReactiveWebSocketHandler.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/ReactiveWebSocketHandler.scala
index 7fc695ab78..2b82684192 100644
--- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/ReactiveWebSocketHandler.scala
+++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/ReactiveWebSocketHandler.scala
@@ -99,9 +99,10 @@ class ReactiveWebSocketHandler[F[_]](
                             decodeCloseRequests = r.decodeCloseRequests
                           )
                         )
-                        r.autoPing.foreach { case (interval, pingMsg) =>
+                        r.autoPing.foreach { case (interval, pingMsg) =>
                           ctx.pipeline().addFirst("wsAutoPingHandler", new WebSocketAutoPingHandler(interval, pingMsg))
                         }
+                        // Manually completing the promise, for some reason it won't be completed in writeAndFlush.
We need its completion for NettyBodyListener to call back properly r.channelPromise.setSuccess() val _ = ctx.writeAndFlush( // Push a special message down the pipeline, it will be handled by HttpStreamsServerHandler From 6b92093bd1b454efe0f843be5188f2f62ebf2396 Mon Sep 17 00:00:00 2001 From: kciesielski Date: Mon, 25 Mar 2024 09:58:23 +0100 Subject: [PATCH 03/31] Remove TODO --- .../server/netty/internal/WebSocketAutoPingHandler.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/WebSocketAutoPingHandler.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/WebSocketAutoPingHandler.scala index 62addee363..189f6daff0 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/WebSocketAutoPingHandler.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/WebSocketAutoPingHandler.scala @@ -10,7 +10,6 @@ import scala.concurrent.duration.FiniteDuration /** If auto ping is enabled for an endpoint, this handler will be plugged into the pipeline. Its responsibility is to manage start and stop * of the ping scheduler. - * TODO: should we include logic for closing the channel and reporting an error if some kind of ping timeout is exceeded? * @param pingInterval * time interval to be used between sending pings to the client. * @param frame @@ -31,7 +30,8 @@ class WebSocketAutoPingHandler(pingInterval: FiniteDuration, frame: sttp.ws.WebS val _ = ctx.writeAndFlush(nettyFrame.retain()) } } - pingTask = ctx.channel().eventLoop().scheduleAtFixedRate(sendPing, pingInterval.toMillis, pingInterval.toMillis, TimeUnit.MILLISECONDS) + pingTask = + ctx.channel().eventLoop().scheduleAtFixedRate(sendPing, pingInterval.toMillis, pingInterval.toMillis, TimeUnit.MILLISECONDS) } } From ecc0f98944061f0a1e268d1bcc0922fca9627898 Mon Sep 17 00:00:00 2001 From: kciesielski Date: Mon, 25 Mar 2024 15:40:48 +0100 Subject: [PATCH 04/31] Fix cancellation and error handling --- .../cats/internal/Fs2StreamCompatible.scala | 20 ++++++++-- .../internal/WebSocketPipeProcessor.scala | 38 +++++++++++++++++-- .../internal/NettyToStreamsResponseBody.scala | 3 +- .../netty/internal/StreamCompatible.scala | 14 ++++--- .../server/tests/ServerWebSocketTests.scala | 30 ++++++++++++++- 5 files changed, 90 insertions(+), 15 deletions(-) diff --git a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/Fs2StreamCompatible.scala b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/Fs2StreamCompatible.scala index 28167f4422..4b97b733fb 100644 --- a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/Fs2StreamCompatible.scala +++ b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/Fs2StreamCompatible.scala @@ -6,7 +6,8 @@ import fs2.interop.reactivestreams.{StreamSubscriber, StreamUnicastPublisher} import fs2.io.file.{Files, Flags, Path} import fs2.{Chunk, Pipe} import io.netty.buffer.Unpooled -import io.netty.handler.codec.http.websocketx.WebSocketFrame +import io.netty.channel.{ChannelFuture, ChannelHandlerContext} +import io.netty.handler.codec.http.websocketx._ import io.netty.handler.codec.http.{DefaultHttpContent, HttpContent} import org.reactivestreams.{Processor, Publisher} import sttp.capabilities.fs2.Fs2Streams @@ -67,9 +68,20 @@ object Fs2StreamCompatible { override def asWsProcessor[REQ, RESP]( pipe: Pipe[F, REQ, RESP], - o: WebSocketBodyOutput[Pipe[F, REQ, RESP], REQ, 
RESP, ?, Fs2Streams[F]] - ): Processor[WebSocketFrame, WebSocketFrame] = - new WebSocketPipeProcessor[F, REQ, RESP](pipe, dispatcher, o) + o: WebSocketBodyOutput[Pipe[F, REQ, RESP], REQ, RESP, ?, Fs2Streams[F]], + ctx: ChannelHandlerContext + ): Processor[WebSocketFrame, WebSocketFrame] = { + val onCancelPromise = ctx.newPromise() + onCancelPromise.addListener((f: ChannelFuture) => { + // A special callback that has to be used when a SteramSubscription cancels. + // This can happen in case of errors in the pipeline which are not signalled correctly, + // like throwing exceptions directly. + // Without explicit Close frame a client may hang on waiting and not knowing about closed channel. + if (f.isSuccess) + val _ = ctx.writeAndFlush(new CloseWebSocketFrame(WebSocketCloseStatus.NORMAL_CLOSURE, "Canceled")) + }) + new WebSocketPipeProcessor[F, REQ, RESP](pipe, dispatcher, o, onCancelPromise) + } private def inputStreamToFs2(inputStream: () => InputStream, chunkSize: Int) = fs2.io.readInputStream( diff --git a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/WebSocketPipeProcessor.scala b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/WebSocketPipeProcessor.scala index 80f2960d4d..202f3847f2 100644 --- a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/WebSocketPipeProcessor.scala +++ b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/WebSocketPipeProcessor.scala @@ -1,12 +1,15 @@ package sttp.tapir.server.netty.cats.internal import cats.Applicative -import cats.effect.kernel.Async +import cats.effect.kernel.{Async, Sync} import cats.effect.std.Dispatcher +import cats.syntax.all._ import fs2.interop.reactivestreams.{StreamSubscriber, StreamUnicastPublisher} import fs2.{Pipe, Stream} +import io.netty.channel.ChannelPromise import io.netty.handler.codec.http.websocketx.{WebSocketFrame => NettyWebSocketFrame} import org.reactivestreams.{Processor, Publisher, Subscriber, Subscription} +import org.slf4j.LoggerFactory import sttp.capabilities.fs2.Fs2Streams import sttp.tapir.model.WebSocketFrameDecodeFailure import sttp.tapir.server.netty.internal.WebSocketFrameConverters._ @@ -15,6 +18,7 @@ import sttp.ws.WebSocketFrame import scala.concurrent.ExecutionContext.Implicits import scala.concurrent.Promise +import scala.util.control.NonFatal import scala.util.{Failure, Success} /** A Reactive Streams Processor[NettyWebSocketFrame, NettyWebSocketFrame] built from a fs2.Pipe[F, REQ, RESP] passed from an WS endpoint. @@ -22,19 +26,22 @@ import scala.util.{Failure, Success} class WebSocketPipeProcessor[F[_]: Async, REQ, RESP]( pipe: Pipe[F, REQ, RESP], dispatcher: Dispatcher[F], - o: WebSocketBodyOutput[Pipe[F, REQ, RESP], REQ, RESP, ?, Fs2Streams[F]] + o: WebSocketBodyOutput[Pipe[F, REQ, RESP], REQ, RESP, ?, Fs2Streams[F]], + onCancel: ChannelPromise ) extends Processor[NettyWebSocketFrame, NettyWebSocketFrame] { private var subscriber: StreamSubscriber[F, NettyWebSocketFrame] = _ private val publisher: Promise[Publisher[NettyWebSocketFrame]] = Promise[Publisher[NettyWebSocketFrame]]() private var subscription: Subscription = _ + private val logger = LoggerFactory.getLogger(getClass.getName) + override def onSubscribe(s: Subscription): Unit = { // Not really that unsafe. Subscriber creation doesn't do any IO, only initializes an AtomicReference in an initial state. 
subscriber = dispatcher.unsafeRunSync( // If bufferSize > 1, the stream may stale and not emit responses until enough requests are buffered StreamSubscriber[F, NettyWebSocketFrame](bufferSize = 1) ) - subscription = s + subscription = new ChannelAwareSubscription(s, onCancel) val in: Stream[F, NettyWebSocketFrame] = subscriber.sub.stream(Applicative[F].unit) val sttpFrames = in.map { f => val sttpFrame = nettyFrameToFrame(f) @@ -51,10 +58,13 @@ class WebSocketPipeProcessor[F[_]: Async, REQ, RESP]( ) .through(pipe) .map(r => frameToNettyFrame(o.responses.encode(r))) + .onError { case NonFatal(t) => + Stream.eval(Sync[F].delay(logger.error("Error occured in WebSocket channel", t))) + } .append(fs2.Stream(frameToNettyFrame(WebSocketFrame.close))) // Trigger listening for WS frames in the underlying fs2 StreamSubscribber - subscriber.sub.onSubscribe(s) + subscriber.sub.onSubscribe(subscription) // Signal that a Publisher is ready to send result frames publisher.success(StreamUnicastPublisher(stream, dispatcher)) } @@ -65,10 +75,12 @@ class WebSocketPipeProcessor[F[_]: Async, REQ, RESP]( override def onError(t: Throwable): Unit = { subscriber.sub.onError(t) + val _ = onCancel.cancel(true) } override def onComplete(): Unit = { subscriber.sub.onComplete() + val _ = onCancel.cancel(true) } override def subscribe(s: Subscriber[_ >: NettyWebSocketFrame]): Unit = { @@ -99,3 +111,21 @@ class WebSocketPipeProcessor[F[_]: Async, REQ, RESP]( }.collect { case (_, Some(f)) => f } } else s } + +/** This wrapped is needed to intercept the logic of StreamSubscription which calls cancel() in case of fatally failing streams. This makes + * errors get swallowed, so we replace delegate.onCancel() with our own onCancel callback that would trigger custom handling logging in + * Netty. Additionally, the stream will fail properly and any errors from the pipeline will be logged. + * + * @param delegate + * a channel subscription which we don't want to notify about cancelation. + * @param onCancel + * our custom cancellation callback. 
+ */ +class ChannelAwareSubscription(delegate: Subscription, onCancel: ChannelPromise) extends Subscription { + override def cancel(): Unit = { + val _ = onCancel.setSuccess() + } + override def request(n: Long): Unit = { + delegate.request(n) + } +} diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyToStreamsResponseBody.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyToStreamsResponseBody.scala index ae94a180b3..998efd390d 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyToStreamsResponseBody.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyToStreamsResponseBody.scala @@ -79,7 +79,8 @@ private[netty] class NettyToStreamsResponseBody[S <: Streams[S]](streamCompatibl ctx.newPromise(), streamCompatible.asWsProcessor( pipe.asInstanceOf[streamCompatible.streams.Pipe[REQ, RESP]], - o.asInstanceOf[WebSocketBodyOutput[streamCompatible.streams.Pipe[REQ, RESP], REQ, RESP, _, S]] + o.asInstanceOf[WebSocketBodyOutput[streamCompatible.streams.Pipe[REQ, RESP], REQ, RESP, _, S]], + ctx ), ignorePong = o.ignorePong, autoPongOnPing = o.autoPongOnPing, diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/StreamCompatible.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/StreamCompatible.scala index b973aef7eb..01f49586a6 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/StreamCompatible.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/StreamCompatible.scala @@ -7,11 +7,11 @@ import sttp.capabilities.Streams import sttp.tapir.{FileRange, WebSocketBodyOutput} import java.io.InputStream +import io.netty.channel.ChannelHandlerContext -/** - * Operations on streams that have to be implemented for each streaming integration (fs2, zio-streams, etc) used by Netty backends. - * This includes conversions like building a stream from a `File`, an `InputStream`, or a reactive `Publisher`. - * We also need implementation of a failed (errored) stream, as well as an empty stream (for handling empty requests). +/** Operations on streams that have to be implemented for each streaming integration (fs2, zio-streams, etc) used by Netty backends. This + * includes conversions like building a stream from a `File`, an `InputStream`, or a reactive `Publisher`. We also need implementation of a + * failed (errored) stream, as well as an empty stream (for handling empty requests). 
*/ private[netty] trait StreamCompatible[S <: Streams[S]] { val streams: S @@ -29,5 +29,9 @@ private[netty] trait StreamCompatible[S <: Streams[S]] { def publisherFromInputStream(is: () => InputStream, chunkSize: Int, length: Option[Long]): Publisher[HttpContent] = asPublisher(fromInputStream(is, chunkSize, length)) - def asWsProcessor[REQ, RESP](pipe: streams.Pipe[REQ, RESP], o: WebSocketBodyOutput[streams.Pipe[REQ, RESP], REQ, RESP, _, S]): Processor[WebSocketFrame, WebSocketFrame] + def asWsProcessor[REQ, RESP]( + pipe: streams.Pipe[REQ, RESP], + o: WebSocketBodyOutput[streams.Pipe[REQ, RESP], REQ, RESP, _, S], + ctx: ChannelHandlerContext + ): Processor[WebSocketFrame, WebSocketFrame] } diff --git a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerWebSocketTests.scala b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerWebSocketTests.scala index 05cfaad8bb..91530fdfc5 100644 --- a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerWebSocketTests.scala +++ b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerWebSocketTests.scala @@ -44,7 +44,7 @@ abstract class ServerWebSocketTests[F[_], S <: Streams[S], OPTIONS, ROUTE]( m1 <- ws.receiveText() m2 <- ws.receiveText() _ <- ws.close() - m3 <- ws.eitherClose(ws.receiveText()) + m3 <- ws.eitherClose(ws.receiveText()) } yield List(m1, m2, m3) }) .get(baseUri.scheme("ws")) @@ -120,6 +120,34 @@ abstract class ServerWebSocketTests[F[_], S <: Streams[S], OPTIONS, ROUTE]( ) ) }, + testServer( + endpoint.out(webSocketBody[String, CodecFormat.TextPlain, String, CodecFormat.TextPlain](streams)), + "failing pipe" + )((_: Unit) => + pureResult(functionToPipe[String, String] { + case "error-trigger" => throw new Exception("Boom!") + case msg => s"echo: $msg" + }.asRight[Unit]) + ) { (backend, baseUri) => + basicRequest + .response(asWebSocket { (ws: WebSocket[IO]) => + for { + _ <- ws.sendText("test1") + _ <- ws.sendText("test2") + _ <- ws.sendText("error-trigger") + m1 <- ws.eitherClose(ws.receiveText()) + m2 <- ws.eitherClose(ws.receiveText()) + m3 <- ws.eitherClose(ws.receiveText()) + } yield List(m1, m2, m3) + }) + .get(baseUri.scheme("ws")) + .send(backend) + .map( + _.body.map(_.map(_.left.map(_.statusCode))) shouldBe Right( + List(Right("echo: test1"), Right("echo: test2"), Left(WebSocketFrame.close.statusCode)) + ) + ) + }, testServer( endpoint.out(webSocketBody[String, CodecFormat.TextPlain, String, CodecFormat.TextPlain](streams)), "empty client stream" From 48ed4f7a88aac24b47d9347077182edfea564d57 Mon Sep 17 00:00:00 2001 From: kciesielski Date: Mon, 25 Mar 2024 21:51:22 +0100 Subject: [PATCH 05/31] A few more tweaks to cancellation and error handling --- .../cats/internal/Fs2StreamCompatible.scala | 8 +++- .../internal/WebSocketPipeProcessor.scala | 39 +++++++++++-------- .../server/tests/ServerWebSocketTests.scala | 2 +- 3 files changed, 29 insertions(+), 20 deletions(-) diff --git a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/Fs2StreamCompatible.scala b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/Fs2StreamCompatible.scala index 4b97b733fb..e74794a1e0 100644 --- a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/Fs2StreamCompatible.scala +++ b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/Fs2StreamCompatible.scala @@ -73,12 +73,16 @@ object Fs2StreamCompatible { ): Processor[WebSocketFrame, WebSocketFrame] = { val onCancelPromise = ctx.newPromise() onCancelPromise.addListener((f: 
ChannelFuture) => { - // A special callback that has to be used when a SteramSubscription cancels. + // A special callback that has to be used when a SteramSubscription cancels or fails. // This can happen in case of errors in the pipeline which are not signalled correctly, // like throwing exceptions directly. // Without explicit Close frame a client may hang on waiting and not knowing about closed channel. - if (f.isSuccess) + if (f.isCancelled) { val _ = ctx.writeAndFlush(new CloseWebSocketFrame(WebSocketCloseStatus.NORMAL_CLOSURE, "Canceled")) + } + else if (!f.isSuccess) { + val _ = ctx.writeAndFlush(new CloseWebSocketFrame(WebSocketCloseStatus.INTERNAL_SERVER_ERROR, "Error")) + } }) new WebSocketPipeProcessor[F, REQ, RESP](pipe, dispatcher, o, onCancelPromise) } diff --git a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/WebSocketPipeProcessor.scala b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/WebSocketPipeProcessor.scala index 202f3847f2..732fef8f56 100644 --- a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/WebSocketPipeProcessor.scala +++ b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/WebSocketPipeProcessor.scala @@ -2,6 +2,7 @@ package sttp.tapir.server.netty.cats.internal import cats.Applicative import cats.effect.kernel.{Async, Sync} +import cats.effect.kernel.Resource.ExitCase import cats.effect.std.Dispatcher import cats.syntax.all._ import fs2.interop.reactivestreams.{StreamSubscriber, StreamUnicastPublisher} @@ -18,7 +19,6 @@ import sttp.ws.WebSocketFrame import scala.concurrent.ExecutionContext.Implicits import scala.concurrent.Promise -import scala.util.control.NonFatal import scala.util.{Failure, Success} /** A Reactive Streams Processor[NettyWebSocketFrame, NettyWebSocketFrame] built from a fs2.Pipe[F, REQ, RESP] passed from an WS endpoint. 
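The promise/listener pattern used above can be read in isolation. Below is a minimal sketch of the same idea, assuming the Netty 4.1 API already used in this patch (ChannelPromise, ChannelFutureListener, CloseWebSocketFrame); closeOnWsCompletion is an illustrative helper name, not part of the change. The processor completes the promise when the stream finishes, and the listener translates a cancellation or failure into an explicit Close frame so the client is not left waiting on a silently dead channel.

import io.netty.channel.{ChannelFuture, ChannelFutureListener, ChannelHandlerContext, ChannelPromise}
import io.netty.handler.codec.http.websocketx.{CloseWebSocketFrame, WebSocketCloseStatus}

def closeOnWsCompletion(ctx: ChannelHandlerContext): ChannelPromise = {
  val promise = ctx.newPromise()
  val _ = promise.addListener(new ChannelFutureListener {
    override def operationComplete(f: ChannelFuture): Unit = {
      if (f.isCancelled) {
        // stream canceled: close the socket gracefully (status 1000)
        val _ = ctx.writeAndFlush(new CloseWebSocketFrame(WebSocketCloseStatus.NORMAL_CLOSURE, "Canceled"))
      } else if (!f.isSuccess) {
        // stream failed: report an internal error to the client (status 1011)
        val _ = ctx.writeAndFlush(new CloseWebSocketFrame(WebSocketCloseStatus.INTERNAL_SERVER_ERROR, "Error"))
      }
      // on success the pipe itself appends a Close frame, so nothing needs to be written here
    }
  })
  promise
}

The 1011 close code is what the adjusted "failing pipe" test below accepts for servers that signal pipeline errors explicitly, alongside the plain 1000 used by servers that close normally.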
@@ -27,7 +27,7 @@ class WebSocketPipeProcessor[F[_]: Async, REQ, RESP]( pipe: Pipe[F, REQ, RESP], dispatcher: Dispatcher[F], o: WebSocketBodyOutput[Pipe[F, REQ, RESP], REQ, RESP, ?, Fs2Streams[F]], - onCancel: ChannelPromise + wsCompletedPromise: ChannelPromise ) extends Processor[NettyWebSocketFrame, NettyWebSocketFrame] { private var subscriber: StreamSubscriber[F, NettyWebSocketFrame] = _ private val publisher: Promise[Publisher[NettyWebSocketFrame]] = Promise[Publisher[NettyWebSocketFrame]]() @@ -41,7 +41,7 @@ class WebSocketPipeProcessor[F[_]: Async, REQ, RESP]( // If bufferSize > 1, the stream may stale and not emit responses until enough requests are buffered StreamSubscriber[F, NettyWebSocketFrame](bufferSize = 1) ) - subscription = new ChannelAwareSubscription(s, onCancel) + subscription = new NonCancelingSubscription(s) val in: Stream[F, NettyWebSocketFrame] = subscriber.sub.stream(Applicative[F].unit) val sttpFrames = in.map { f => val sttpFrame = nettyFrameToFrame(f) @@ -58,8 +58,13 @@ class WebSocketPipeProcessor[F[_]: Async, REQ, RESP]( ) .through(pipe) .map(r => frameToNettyFrame(o.responses.encode(r))) - .onError { case NonFatal(t) => - Stream.eval(Sync[F].delay(logger.error("Error occured in WebSocket channel", t))) + .onFinalizeCaseWeak { + case ExitCase.Succeeded => + Sync[F].delay { val _ = wsCompletedPromise.setSuccess() } + case ExitCase.Errored(t) => + Sync[F].delay(wsCompletedPromise.setFailure(t)) >> Sync[F].delay(logger.error("Error occured in WebSocket channel", t)) + case ExitCase.Canceled => + Sync[F].delay { val _ = wsCompletedPromise.cancel(true) } } .append(fs2.Stream(frameToNettyFrame(WebSocketFrame.close))) @@ -75,12 +80,16 @@ class WebSocketPipeProcessor[F[_]: Async, REQ, RESP]( override def onError(t: Throwable): Unit = { subscriber.sub.onError(t) - val _ = onCancel.cancel(true) + if (!wsCompletedPromise.isDone()) { + val _ = wsCompletedPromise.setFailure(t) + } } override def onComplete(): Unit = { subscriber.sub.onComplete() - val _ = onCancel.cancel(true) + if (!wsCompletedPromise.isDone()) { + val _ = wsCompletedPromise.setSuccess() + } } override def subscribe(s: Subscriber[_ >: NettyWebSocketFrame]): Unit = { @@ -112,19 +121,15 @@ class WebSocketPipeProcessor[F[_]: Async, REQ, RESP]( } else s } -/** This wrapped is needed to intercept the logic of StreamSubscription which calls cancel() in case of fatally failing streams. This makes - * errors get swallowed, so we replace delegate.onCancel() with our own onCancel callback that would trigger custom handling logging in - * Netty. Additionally, the stream will fail properly and any errors from the pipeline will be logged. - * +/** A special wrapper used to override internal logic of fs2, which calls cancel() silently when internal stream failures happen, causing + * the subscription to close the channel and stop the subscriber in such a way that errors can't get handled properly. With this wrapper we + * intentionally don't do anything on cancel(), so that the stream continues to fail properly on errors. We are handling cancelation + * manually with a channel promise passed to the processor logic. * @param delegate * a channel subscription which we don't want to notify about cancelation. - * @param onCancel - * our custom cancellation callback. 
*/ -class ChannelAwareSubscription(delegate: Subscription, onCancel: ChannelPromise) extends Subscription { - override def cancel(): Unit = { - val _ = onCancel.setSuccess() - } +class NonCancelingSubscription(delegate: Subscription) extends Subscription { + override def cancel(): Unit = () override def request(n: Long): Unit = { delegate.request(n) } diff --git a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerWebSocketTests.scala b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerWebSocketTests.scala index 91530fdfc5..bab9976453 100644 --- a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerWebSocketTests.scala +++ b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerWebSocketTests.scala @@ -144,7 +144,7 @@ abstract class ServerWebSocketTests[F[_], S <: Streams[S], OPTIONS, ROUTE]( .send(backend) .map( _.body.map(_.map(_.left.map(_.statusCode))) shouldBe Right( - List(Right("echo: test1"), Right("echo: test2"), Left(WebSocketFrame.close.statusCode)) + List(Right("echo: test1"), Right("echo: test2"), Left(1011)) ) ) }, From baa982017cfd4a4d58b0720b48395ad11b8748ea Mon Sep 17 00:00:00 2001 From: kciesielski Date: Tue, 26 Mar 2024 10:15:16 +0100 Subject: [PATCH 06/31] Add volatiles where needed --- .../server/netty/cats/internal/WebSocketPipeProcessor.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/WebSocketPipeProcessor.scala b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/WebSocketPipeProcessor.scala index 732fef8f56..441d5d60ef 100644 --- a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/WebSocketPipeProcessor.scala +++ b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/WebSocketPipeProcessor.scala @@ -29,9 +29,9 @@ class WebSocketPipeProcessor[F[_]: Async, REQ, RESP]( o: WebSocketBodyOutput[Pipe[F, REQ, RESP], REQ, RESP, ?, Fs2Streams[F]], wsCompletedPromise: ChannelPromise ) extends Processor[NettyWebSocketFrame, NettyWebSocketFrame] { - private var subscriber: StreamSubscriber[F, NettyWebSocketFrame] = _ + @volatile private var subscriber: StreamSubscriber[F, NettyWebSocketFrame] = _ private val publisher: Promise[Publisher[NettyWebSocketFrame]] = Promise[Publisher[NettyWebSocketFrame]]() - private var subscription: Subscription = _ + @volatile private var subscription: Subscription = _ private val logger = LoggerFactory.getLogger(getClass.getName) From d9b9a696e1d4cd769e6bbad09a72f76f67148e6d Mon Sep 17 00:00:00 2001 From: kciesielski Date: Tue, 26 Mar 2024 20:39:11 +0100 Subject: [PATCH 07/31] Auto-ping --- .../server/http4s/Http4sServerTest.scala | 36 +++++++++++++++++++ .../internal/ReactiveWebSocketHandler.scala | 4 ++- .../internal/WebSocketAutoPingHandler.scala | 4 ++- .../server/tests/ServerWebSocketTests.scala | 29 ++++++++++++++- 4 files changed, 70 insertions(+), 3 deletions(-) diff --git a/server/http4s-server/src/test/scala/sttp/tapir/server/http4s/Http4sServerTest.scala b/server/http4s-server/src/test/scala/sttp/tapir/server/http4s/Http4sServerTest.scala index 844e628024..4844a56c3b 100644 --- a/server/http4s-server/src/test/scala/sttp/tapir/server/http4s/Http4sServerTest.scala +++ b/server/http4s-server/src/test/scala/sttp/tapir/server/http4s/Http4sServerTest.scala @@ -5,6 +5,7 @@ import cats.effect._ import cats.effect.unsafe.implicits.global import cats.syntax.all._ import fs2.Pipe +import fs2.Stream import 
org.http4s.blaze.server.BlazeServerBuilder import org.http4s.server.Router import org.http4s.server.ContextMiddleware @@ -130,6 +131,41 @@ class Http4sServerTest[R >: Fs2Streams[IO] with WebSockets] extends TestSuite wi .get(baseUri) .send(backend) .map(_.body.right.toOption.value shouldBe List(sse1, sse2)) + }, + createServerTest.testServer( + endpoint.out(webSocketBody[String, CodecFormat.TextPlain, String, CodecFormat.TextPlain](Fs2Streams[IO])), + "canceled server WebSocket stream" + )((_: Unit) => + pureResult( + ( + (in: Stream[IO, String]) => + Stream + .eval(Ref.of[IO, Int](0)) + .flatMap { counter => + Stream.repeatEval( + counter.get.flatMap(c => if (c > 2) IO.canceled.as("canceled") else counter.update(_ + 1) >> IO.pure(c.toString)) + ) + } + .concurrently(in.as(())) + ).asRight[Unit] + ) + ) { (backend, baseUri) => + basicRequest + .response(asWebSocket { (ws: WebSocket[IO]) => + for { + _ <- ws.sendText("start") + m1 <- ws.eitherClose(ws.receiveText()) + m2 <- ws.eitherClose(ws.receiveText()) + m3 <- ws.eitherClose(ws.receiveText()) + } yield List(m1, m2, m3) + }) + .get(baseUri.scheme("ws")) + .send(backend) + .map( + _.body.map(_.map(_.left.map(_.statusCode))) shouldBe Right( + List(Right("0"), Right("1"), Right("2"), Left(WebSocketFrame.close.statusCode)) + ) + ) } ) diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/ReactiveWebSocketHandler.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/ReactiveWebSocketHandler.scala index 2b82684192..e00d3467cd 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/ReactiveWebSocketHandler.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/ReactiveWebSocketHandler.scala @@ -100,7 +100,9 @@ class ReactiveWebSocketHandler[F[_]]( ) ) r.autoPing.foreach { case (interval, pingMsg) => - ctx.pipeline().addFirst("wsAutoPingHandler", new WebSocketAutoPingHandler(interval, pingMsg)) + ctx + .pipeline() + .addAfter("wsControlFrameHandler", "wsAutoPingHandler", new WebSocketAutoPingHandler(interval, pingMsg)) } // Manually completing the promise, for some reason it won't be completed in writeAndFlush. 
We need its completion for NettyBodyListener to call back properly r.channelPromise.setSuccess() diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/WebSocketAutoPingHandler.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/WebSocketAutoPingHandler.scala index 189f6daff0..401d668ca5 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/WebSocketAutoPingHandler.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/WebSocketAutoPingHandler.scala @@ -2,7 +2,7 @@ package sttp.tapir.server.netty.internal import io.netty.buffer.Unpooled import io.netty.channel.{ChannelHandlerContext, ChannelInboundHandlerAdapter} -import io.netty.handler.codec.http.websocketx.PingWebSocketFrame +import io.netty.handler.codec.http.websocketx._ import org.slf4j.LoggerFactory import java.util.concurrent.{ScheduledFuture, TimeUnit} @@ -22,6 +22,7 @@ class WebSocketAutoPingHandler(pingInterval: FiniteDuration, frame: sttp.ws.WebS private val logger = LoggerFactory.getLogger(getClass.getName) override def handlerAdded(ctx: ChannelHandlerContext): Unit = { + super.handlerAdded(ctx) if (ctx.channel.isActive) { logger.debug(s"STARTING WebSocket Ping scheduler for channel ${ctx.channel}, interval = $pingInterval") val sendPing: Runnable = new Runnable { @@ -30,6 +31,7 @@ class WebSocketAutoPingHandler(pingInterval: FiniteDuration, frame: sttp.ws.WebS val _ = ctx.writeAndFlush(nettyFrame.retain()) } } + // FIXME should not start before the handshake response is sent! pingTask = ctx.channel().eventLoop().scheduleAtFixedRate(sendPing, pingInterval.toMillis, pingInterval.toMillis, TimeUnit.MILLISECONDS) } diff --git a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerWebSocketTests.scala b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerWebSocketTests.scala index bab9976453..2f7a24c6a7 100644 --- a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerWebSocketTests.scala +++ b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerWebSocketTests.scala @@ -4,6 +4,7 @@ import cats.effect.IO import cats.syntax.all._ import io.circe.generic.auto._ import org.scalatest.matchers.should.Matchers._ +import org.scalatest.EitherValues import sttp.capabilities.{Streams, WebSockets} import sttp.client3._ import sttp.monad.MonadError @@ -16,13 +17,14 @@ import sttp.tapir.server.tests.ServerMetricsTest._ import sttp.tapir.tests.Test import sttp.tapir.tests.data.Fruit import sttp.ws.{WebSocket, WebSocketFrame} +import scala.concurrent.duration._ abstract class ServerWebSocketTests[F[_], S <: Streams[S], OPTIONS, ROUTE]( createServerTest: CreateServerTest[F, S with WebSockets, OPTIONS, ROUTE], val streams: S )(implicit m: MonadError[F] -) { +) extends EitherValues { import createServerTest._ def functionToPipe[A, B](f: A => B): streams.Pipe[A, B] @@ -148,6 +150,31 @@ abstract class ServerWebSocketTests[F[_], S <: Streams[S], OPTIONS, ROUTE]( ) ) }, + testServer( + endpoint.out( + webSocketBody[String, CodecFormat.TextPlain, String, CodecFormat.TextPlain](streams) + .autoPing(Some((200.millis, WebSocketFrame.ping))) + ), + "auto ping" + )((_: Unit) => pureResult(stringEcho.asRight[Unit])) { (backend, baseUri) => + basicRequest + .response(asWebSocket { (ws: WebSocket[IO]) => + for { + _ <- ws.sendText("test1") + _ <- IO.sleep(250.millis) + _ <- ws.sendText("test2") + m1 <- ws.receive() + m2 <- ws.receive() + _ <- ws.sendText("test3") + m3 <- ws.receive() + } yield List(m1, m2, m3) + }) + 
.get(baseUri.scheme("ws")) + .send(backend) + .map((r: Response[Either[String, List[WebSocketFrame]]]) => + assert(r.body.value.exists(_.isInstanceOf[WebSocketFrame.Ping]), s"Missing Ping frame in WS responses: $r") + ) + }, testServer( endpoint.out(webSocketBody[String, CodecFormat.TextPlain, String, CodecFormat.TextPlain](streams)), "empty client stream" From 4d1e11d6d8884ee68ac98c849c7498adf9a8bedf Mon Sep 17 00:00:00 2001 From: kciesielski Date: Tue, 26 Mar 2024 22:52:18 +0100 Subject: [PATCH 08/31] Adjust netty-zio --- .../tapir/server/netty/zio/internal/ZioStreamCompatible.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/internal/ZioStreamCompatible.scala b/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/internal/ZioStreamCompatible.scala index 14ad954784..bf0f99fe26 100644 --- a/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/internal/ZioStreamCompatible.scala +++ b/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/internal/ZioStreamCompatible.scala @@ -12,6 +12,7 @@ import sttp.tapir.server.netty.internal._ import sttp.tapir.{FileRange, WebSocketBodyOutput} import java.io.InputStream +import io.netty.channel.ChannelHandlerContext private[zio] object ZioStreamCompatible { @@ -68,7 +69,8 @@ private[zio] object ZioStreamCompatible { override def asWsProcessor[REQ, RESP]( pipe: Stream[Throwable, REQ] => Stream[Throwable, RESP], - o: WebSocketBodyOutput[Stream[Throwable, REQ] => Stream[Throwable, RESP], REQ, RESP, ?, ZioStreams] + o: WebSocketBodyOutput[Stream[Throwable, REQ] => Stream[Throwable, RESP], REQ, RESP, ?, ZioStreams], + ctx: ChannelHandlerContext ): Processor[WebSocketFrame, WebSocketFrame] = throw new UnsupportedOperationException("TODO") } From f9bb21e82a7c61f8644f0f1e850d9bb7e9d2a7ea Mon Sep 17 00:00:00 2001 From: kciesielski Date: Tue, 26 Mar 2024 23:16:42 +0100 Subject: [PATCH 09/31] Don't decode Ping --- .../WebSocketControlFrameHandler.scala | 9 ++-- .../server/tests/ServerWebSocketTests.scala | 48 +++++++++++++++++++ 2 files changed, 53 insertions(+), 4 deletions(-) diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/WebSocketControlFrameHandler.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/WebSocketControlFrameHandler.scala index 098e76bcfe..ffdb7bbe06 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/WebSocketControlFrameHandler.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/WebSocketControlFrameHandler.scala @@ -6,16 +6,17 @@ import io.netty.handler.codec.http.websocketx.PingWebSocketFrame import io.netty.handler.codec.http.websocketx.PongWebSocketFrame import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame -/** - * Handles Ping, Pong, and Close frames for WebSockets. +/** Handles Ping, Pong, and Close frames for WebSockets. 
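+  * @param ignorePong
+  *   if true, incoming Pong frames are dropped by this handler instead of being passed further down the pipeline.
+  * @param autoPongOnPing
+  *   if true, a Pong frame carrying the Ping's payload is written back automatically for every incoming Ping.
+  * @param decodeCloseRequests
+  *   controls whether Close frames are passed on for further processing, so that the endpoint can decode them as close requests.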
*/ class NettyControlFrameHandler(ignorePong: Boolean, autoPongOnPing: Boolean, decodeCloseRequests: Boolean) extends ChannelInboundHandlerAdapter { override def channelRead(ctx: ChannelHandlerContext, msg: Any): Unit = { msg match { - case ping: PingWebSocketFrame if autoPongOnPing => - val _ = ctx.writeAndFlush(new PongWebSocketFrame(ping.content().retain())) + case ping: PingWebSocketFrame => + if (autoPongOnPing) { + val _ = ctx.writeAndFlush(new PongWebSocketFrame(ping.content().retain())) + } case pong: PongWebSocketFrame if !ignorePong => val _ = ctx.fireChannelRead(pong) case close: CloseWebSocketFrame => diff --git a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerWebSocketTests.scala b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerWebSocketTests.scala index 2f7a24c6a7..aff0ed2e27 100644 --- a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerWebSocketTests.scala +++ b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerWebSocketTests.scala @@ -175,6 +175,54 @@ abstract class ServerWebSocketTests[F[_], S <: Streams[S], OPTIONS, ROUTE]( assert(r.body.value.exists(_.isInstanceOf[WebSocketFrame.Ping]), s"Missing Ping frame in WS responses: $r") ) }, + testServer( + endpoint.out( + webSocketBody[String, CodecFormat.TextPlain, String, CodecFormat.TextPlain](streams) + .autoPing(None) + .autoPongOnPing(true) + ), + "pong on ping" + )((_: Unit) => pureResult(stringEcho.asRight[Unit])) { (backend, baseUri) => + basicRequest + .response(asWebSocket { (ws: WebSocket[IO]) => + for { + _ <- ws.sendText("test1") + _ <- ws.send(WebSocketFrame.Ping("test-ping-text".getBytes())) + m1 <- ws.receive() + _ <- ws.sendText("test2") + m2 <- ws.receive() + } yield List(m1, m2) + }) + .get(baseUri.scheme("ws")) + .send(backend) + .map((r: Response[Either[String, List[WebSocketFrame]]]) => + assert(r.body.value.exists(_.isInstanceOf[WebSocketFrame.Pong]), s"Missing Pong frame in WS responses: $r") + ) + }, + testServer( + endpoint.out( + webSocketBody[String, CodecFormat.TextPlain, String, CodecFormat.TextPlain](streams) + .autoPing(None) + .autoPongOnPing(false) + ), + "not pong on ping if disabled" + )((_: Unit) => pureResult(stringEcho.asRight[Unit])) { (backend, baseUri) => + basicRequest + .response(asWebSocket { (ws: WebSocket[IO]) => + for { + _ <- ws.sendText("test1") + _ <- ws.send(WebSocketFrame.Ping("test-ping-text".getBytes())) + m1 <- ws.receiveText() + _ <- ws.sendText("test2") + m2 <- ws.receiveText() + } yield List(m1, m2) + }) + .get(baseUri.scheme("ws")) + .send(backend) + .map( + _.body shouldBe Right(List("echo: test1", "echo: test2")) + ) + }, testServer( endpoint.out(webSocketBody[String, CodecFormat.TextPlain, String, CodecFormat.TextPlain](streams)), "empty client stream" From f420423c7efb6379945a2b084eb717336d285aa4 Mon Sep 17 00:00:00 2001 From: kciesielski Date: Tue, 26 Mar 2024 23:42:45 +0100 Subject: [PATCH 10/31] Remove cancelation test --- .../server/http4s/Http4sServerTest.scala | 35 ------------------- 1 file changed, 35 deletions(-) diff --git a/server/http4s-server/src/test/scala/sttp/tapir/server/http4s/Http4sServerTest.scala b/server/http4s-server/src/test/scala/sttp/tapir/server/http4s/Http4sServerTest.scala index 4844a56c3b..745ec69f68 100644 --- a/server/http4s-server/src/test/scala/sttp/tapir/server/http4s/Http4sServerTest.scala +++ b/server/http4s-server/src/test/scala/sttp/tapir/server/http4s/Http4sServerTest.scala @@ -131,41 +131,6 @@ class Http4sServerTest[R >: Fs2Streams[IO] with WebSockets] extends TestSuite wi 
.get(baseUri) .send(backend) .map(_.body.right.toOption.value shouldBe List(sse1, sse2)) - }, - createServerTest.testServer( - endpoint.out(webSocketBody[String, CodecFormat.TextPlain, String, CodecFormat.TextPlain](Fs2Streams[IO])), - "canceled server WebSocket stream" - )((_: Unit) => - pureResult( - ( - (in: Stream[IO, String]) => - Stream - .eval(Ref.of[IO, Int](0)) - .flatMap { counter => - Stream.repeatEval( - counter.get.flatMap(c => if (c > 2) IO.canceled.as("canceled") else counter.update(_ + 1) >> IO.pure(c.toString)) - ) - } - .concurrently(in.as(())) - ).asRight[Unit] - ) - ) { (backend, baseUri) => - basicRequest - .response(asWebSocket { (ws: WebSocket[IO]) => - for { - _ <- ws.sendText("start") - m1 <- ws.eitherClose(ws.receiveText()) - m2 <- ws.eitherClose(ws.receiveText()) - m3 <- ws.eitherClose(ws.receiveText()) - } yield List(m1, m2, m3) - }) - .get(baseUri.scheme("ws")) - .send(backend) - .map( - _.body.map(_.map(_.left.map(_.statusCode))) shouldBe Right( - List(Right("0"), Right("1"), Right("2"), Left(WebSocketFrame.close.statusCode)) - ) - ) } ) From aabdf15e922f30e9e9b060fd8a4e4742dfb81421 Mon Sep 17 00:00:00 2001 From: kciesielski Date: Tue, 26 Mar 2024 23:58:08 +0100 Subject: [PATCH 11/31] Adjust error handling test to http4s --- .../tapir/server/tests/ServerWebSocketTests.scala | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerWebSocketTests.scala b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerWebSocketTests.scala index aff0ed2e27..37dee6d201 100644 --- a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerWebSocketTests.scala +++ b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerWebSocketTests.scala @@ -17,6 +17,7 @@ import sttp.tapir.server.tests.ServerMetricsTest._ import sttp.tapir.tests.Test import sttp.tapir.tests.data.Fruit import sttp.ws.{WebSocket, WebSocketFrame} + import scala.concurrent.duration._ abstract class ServerWebSocketTests[F[_], S <: Streams[S], OPTIONS, ROUTE]( @@ -144,11 +145,13 @@ abstract class ServerWebSocketTests[F[_], S <: Streams[S], OPTIONS, ROUTE]( }) .get(baseUri.scheme("ws")) .send(backend) - .map( - _.body.map(_.map(_.left.map(_.statusCode))) shouldBe Right( - List(Right("echo: test1"), Right("echo: test2"), Left(1011)) - ) - ) + .map { r => + val results = r.body.map(_.map(_.left.map(_.statusCode))).value + results.take(2) shouldBe + List(Right("echo: test1"), Right("echo: test2")) + val closeCode = results.last.left.value + assert(closeCode == 1000 || closeCode == 1011) // some servers respond with Close(normal), some with Close(error) + } }, testServer( endpoint.out( From da2860f793b18e45a9fbe1a8b2e40a2d8712562f Mon Sep 17 00:00:00 2001 From: kciesielski Date: Wed, 27 Mar 2024 00:50:53 +0100 Subject: [PATCH 12/31] Correctly read bytes from Netty Frames --- .../netty/internal/WebSocketFrameConverters.scala | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/WebSocketFrameConverters.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/WebSocketFrameConverters.scala index 2d5bc06615..e2d866aeac 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/WebSocketFrameConverters.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/WebSocketFrameConverters.scala @@ -1,19 +1,25 @@ package sttp.tapir.server.netty.internal +import 
io.netty.buffer.{ByteBuf, Unpooled} import io.netty.handler.codec.http.websocketx._ import io.netty.handler.codec.http.websocketx.{WebSocketFrame => NettyWebSocketFrame} import sttp.ws.WebSocketFrame -import io.netty.buffer.Unpooled object WebSocketFrameConverters { + def getBytes(buf: ByteBuf): Array[Byte] = { + val bytes = new Array[Byte](buf.readableBytes()) + buf.readBytes(bytes) + bytes + } + def nettyFrameToFrame(nettyFrame: NettyWebSocketFrame): WebSocketFrame = { nettyFrame match { case text: TextWebSocketFrame => WebSocketFrame.Text(text.text, text.isFinalFragment, Some(text.rsv)) case close: CloseWebSocketFrame => WebSocketFrame.Close(close.statusCode, close.reasonText) - case ping: PingWebSocketFrame => WebSocketFrame.Ping(ping.content().nioBuffer().array()) - case pong: PongWebSocketFrame => WebSocketFrame.Pong(pong.content().nioBuffer().array()) - case _ => WebSocketFrame.Binary(nettyFrame.content().nioBuffer().array(), nettyFrame.isFinalFragment, Some(nettyFrame.rsv)) + case ping: PingWebSocketFrame => WebSocketFrame.Ping(getBytes(ping.content())) + case pong: PongWebSocketFrame => WebSocketFrame.Pong(getBytes(pong.content())) + case _ => WebSocketFrame.Binary(getBytes(nettyFrame.content()), nettyFrame.isFinalFragment, Some(nettyFrame.rsv)) } } From 9eeb2aceff49f5b2df5c06c0d424f2545e5bcf03 Mon Sep 17 00:00:00 2001 From: kciesielski Date: Wed, 27 Mar 2024 01:36:57 +0100 Subject: [PATCH 13/31] Parameterize tests --- .../server/akkahttp/AkkaHttpServerTest.scala | 2 +- .../server/http4s/Http4sServerTest.scala | 2 +- .../http4s/ztapir/ZHttp4sServerTest.scala | 2 +- .../netty/cats/NettyCatsServerTest.scala | 2 +- .../WebSocketControlFrameHandler.scala | 6 +- .../pekkohttp/PekkoHttpServerTest.scala | 2 +- .../tapir/server/play/PlayServerTest.scala | 2 +- .../tapir/server/play/PlayServerTest.scala | 2 +- .../server/tests/ServerWebSocketTests.scala | 175 ++++++++++++------ .../vertx/cats/CatsVertxServerTest.scala | 2 +- .../tapir/server/vertx/VertxServerTest.scala | 4 +- .../server/vertx/zio/ZioVertxServerTest.scala | 2 +- .../server/ziohttp/ZioHttpServerTest.scala | 2 +- 13 files changed, 132 insertions(+), 73 deletions(-) diff --git a/server/akka-http-server/src/test/scala/sttp/tapir/server/akkahttp/AkkaHttpServerTest.scala b/server/akka-http-server/src/test/scala/sttp/tapir/server/akkahttp/AkkaHttpServerTest.scala index 3cb5e63b73..49e0fc2aaa 100644 --- a/server/akka-http-server/src/test/scala/sttp/tapir/server/akkahttp/AkkaHttpServerTest.scala +++ b/server/akka-http-server/src/test/scala/sttp/tapir/server/akkahttp/AkkaHttpServerTest.scala @@ -157,7 +157,7 @@ class AkkaHttpServerTest extends TestSuite with EitherValues { new AllServerTests(createServerTest, interpreter, backend).tests() ++ new ServerStreamingTests(createServerTest).tests(AkkaStreams)(drainAkka) ++ - new ServerWebSocketTests(createServerTest, AkkaStreams) { + new ServerWebSocketTests(createServerTest, AkkaStreams, autoPing = false, failingPipe = true, handlePong = false) { override def functionToPipe[A, B](f: A => B): streams.Pipe[A, B] = Flow.fromFunction(f) override def emptyPipe[A, B]: Flow[A, B, Any] = Flow.fromSinkAndSource(Sink.ignore, Source.empty) }.tests() ++ diff --git a/server/http4s-server/src/test/scala/sttp/tapir/server/http4s/Http4sServerTest.scala b/server/http4s-server/src/test/scala/sttp/tapir/server/http4s/Http4sServerTest.scala index 745ec69f68..530aad401f 100644 --- a/server/http4s-server/src/test/scala/sttp/tapir/server/http4s/Http4sServerTest.scala +++ 
b/server/http4s-server/src/test/scala/sttp/tapir/server/http4s/Http4sServerTest.scala @@ -139,7 +139,7 @@ class Http4sServerTest[R >: Fs2Streams[IO] with WebSockets] extends TestSuite wi new AllServerTests(createServerTest, interpreter, backend).tests() ++ new ServerStreamingTests(createServerTest).tests(Fs2Streams[IO])(drainFs2) ++ - new ServerWebSocketTests(createServerTest, Fs2Streams[IO]) { + new ServerWebSocketTests(createServerTest, Fs2Streams[IO], autoPing = true, failingPipe = true, handlePong = false) { override def functionToPipe[A, B](f: A => B): streams.Pipe[A, B] = in => in.map(f) override def emptyPipe[A, B]: Pipe[IO, A, B] = _ => fs2.Stream.empty }.tests() ++ diff --git a/server/http4s-server/zio/src/test/scala/sttp/tapir/server/http4s/ztapir/ZHttp4sServerTest.scala b/server/http4s-server/zio/src/test/scala/sttp/tapir/server/http4s/ztapir/ZHttp4sServerTest.scala index 99ff7c2d4d..b2cfc5ae58 100644 --- a/server/http4s-server/zio/src/test/scala/sttp/tapir/server/http4s/ztapir/ZHttp4sServerTest.scala +++ b/server/http4s-server/zio/src/test/scala/sttp/tapir/server/http4s/ztapir/ZHttp4sServerTest.scala @@ -55,7 +55,7 @@ class ZHttp4sServerTest extends TestSuite with OptionValues { new AllServerTests(createServerTest, interpreter, backend).tests() ++ new ServerStreamingTests(createServerTest).tests(ZioStreams)(drainZStream) ++ - new ServerWebSocketTests(createServerTest, ZioStreams) { + new ServerWebSocketTests(createServerTest, ZioStreams, autoPing = true, failingPipe = false, handlePong = false) { override def functionToPipe[A, B](f: A => B): streams.Pipe[A, B] = in => in.map(f) override def emptyPipe[A, B]: streams.Pipe[A, B] = _ => ZStream.empty }.tests() ++ diff --git a/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsServerTest.scala b/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsServerTest.scala index e3252a25a1..49dfc5d479 100644 --- a/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsServerTest.scala +++ b/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsServerTest.scala @@ -41,7 +41,7 @@ class NettyCatsServerTest extends TestSuite with EitherValues { new ServerCancellationTests(createServerTest)(m, IO.asyncForIO).tests() ++ new NettyFs2StreamingCancellationTest(createServerTest).tests() ++ new ServerGracefulShutdownTests(createServerTest, ioSleeper).tests() ++ - new ServerWebSocketTests(createServerTest, Fs2Streams[IO]) { + new ServerWebSocketTests(createServerTest, Fs2Streams[IO], autoPing = true, failingPipe = true, handlePong = true) { override def functionToPipe[A, B](f: A => B): streams.Pipe[A, B] = in => in.map(f) override def emptyPipe[A, B]: fs2.Pipe[IO, A, B] = _ => fs2.Stream.empty }.tests() diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/WebSocketControlFrameHandler.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/WebSocketControlFrameHandler.scala index ffdb7bbe06..cd7e6ca919 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/WebSocketControlFrameHandler.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/WebSocketControlFrameHandler.scala @@ -17,8 +17,10 @@ class NettyControlFrameHandler(ignorePong: Boolean, autoPongOnPing: Boolean, dec if (autoPongOnPing) { val _ = ctx.writeAndFlush(new PongWebSocketFrame(ping.content().retain())) } - case pong: PongWebSocketFrame if !ignorePong => - val _ = ctx.fireChannelRead(pong) + case 
pong: PongWebSocketFrame => + if (!ignorePong) { + val _ = ctx.fireChannelRead(pong) + } case close: CloseWebSocketFrame => if (decodeCloseRequests) { // Passing the Close frame for further processing diff --git a/server/pekko-http-server/src/test/scala/sttp/tapir/server/pekkohttp/PekkoHttpServerTest.scala b/server/pekko-http-server/src/test/scala/sttp/tapir/server/pekkohttp/PekkoHttpServerTest.scala index 7f15eecf29..a48dc86063 100644 --- a/server/pekko-http-server/src/test/scala/sttp/tapir/server/pekkohttp/PekkoHttpServerTest.scala +++ b/server/pekko-http-server/src/test/scala/sttp/tapir/server/pekkohttp/PekkoHttpServerTest.scala @@ -105,7 +105,7 @@ class PekkoHttpServerTest extends TestSuite with EitherValues { new AllServerTests(createServerTest, interpreter, backend).tests() ++ new ServerStreamingTests(createServerTest).tests(PekkoStreams)(drainPekko) ++ - new ServerWebSocketTests(createServerTest, PekkoStreams) { + new ServerWebSocketTests(createServerTest, PekkoStreams, autoPing = false, failingPipe = true, handlePong = false) { override def functionToPipe[A, B](f: A => B): streams.Pipe[A, B] = Flow.fromFunction(f) override def emptyPipe[A, B]: Flow[A, B, Any] = Flow.fromSinkAndSource(Sink.ignore, Source.empty) }.tests() ++ diff --git a/server/play-server/src/test/scala/sttp/tapir/server/play/PlayServerTest.scala b/server/play-server/src/test/scala/sttp/tapir/server/play/PlayServerTest.scala index 3662a63fd3..af479d230e 100644 --- a/server/play-server/src/test/scala/sttp/tapir/server/play/PlayServerTest.scala +++ b/server/play-server/src/test/scala/sttp/tapir/server/play/PlayServerTest.scala @@ -122,7 +122,7 @@ class PlayServerTest extends TestSuite { ).tests() ++ new ServerStreamingTests(createServerTest).tests(PekkoStreams)(drainPekko) ++ new PlayServerWithContextTest(backend).tests() ++ - new ServerWebSocketTests(createServerTest, PekkoStreams) { + new ServerWebSocketTests(createServerTest, PekkoStreams, autoPing = false, failingPipe = true, handlePong = false) { override def functionToPipe[A, B](f: A => B): streams.Pipe[A, B] = Flow.fromFunction(f) override def emptyPipe[A, B]: Flow[A, B, Any] = Flow.fromSinkAndSource(Sink.ignore, Source.empty) }.tests() ++ diff --git a/server/play29-server/src/test/scala/sttp/tapir/server/play/PlayServerTest.scala b/server/play29-server/src/test/scala/sttp/tapir/server/play/PlayServerTest.scala index 096b077f44..2634bc5fa9 100644 --- a/server/play29-server/src/test/scala/sttp/tapir/server/play/PlayServerTest.scala +++ b/server/play29-server/src/test/scala/sttp/tapir/server/play/PlayServerTest.scala @@ -117,7 +117,7 @@ class PlayServerTest extends TestSuite { new AllServerTests(createServerTest, interpreter, backend, basic = false, multipart = false, options = false).tests() ++ new ServerStreamingTests(createServerTest).tests(AkkaStreams)(drainAkka) ++ new PlayServerWithContextTest(backend).tests() ++ - new ServerWebSocketTests(createServerTest, AkkaStreams) { + new ServerWebSocketTests(createServerTest, AkkaStreams, autoPing = false, failingPipe = true, handlePong = false) { override def functionToPipe[A, B](f: A => B): streams.Pipe[A, B] = Flow.fromFunction(f) override def emptyPipe[A, B]: Flow[A, B, Any] = Flow.fromSinkAndSource(Sink.ignore, Source.empty) }.tests() ++ diff --git a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerWebSocketTests.scala b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerWebSocketTests.scala index 37dee6d201..7c5fac8d63 100644 --- 
a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerWebSocketTests.scala +++ b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerWebSocketTests.scala @@ -19,10 +19,14 @@ import sttp.tapir.tests.data.Fruit import sttp.ws.{WebSocket, WebSocketFrame} import scala.concurrent.duration._ +import sttp.tapir.model.UnsupportedWebSocketFrameException abstract class ServerWebSocketTests[F[_], S <: Streams[S], OPTIONS, ROUTE]( createServerTest: CreateServerTest[F, S with WebSockets, OPTIONS, ROUTE], - val streams: S + val streams: S, + autoPing: Boolean, + failingPipe: Boolean, + handlePong: Boolean )(implicit m: MonadError[F] ) extends EitherValues { @@ -123,61 +127,6 @@ abstract class ServerWebSocketTests[F[_], S <: Streams[S], OPTIONS, ROUTE]( ) ) }, - testServer( - endpoint.out(webSocketBody[String, CodecFormat.TextPlain, String, CodecFormat.TextPlain](streams)), - "failing pipe" - )((_: Unit) => - pureResult(functionToPipe[String, String] { - case "error-trigger" => throw new Exception("Boom!") - case msg => s"echo: $msg" - }.asRight[Unit]) - ) { (backend, baseUri) => - basicRequest - .response(asWebSocket { (ws: WebSocket[IO]) => - for { - _ <- ws.sendText("test1") - _ <- ws.sendText("test2") - _ <- ws.sendText("error-trigger") - m1 <- ws.eitherClose(ws.receiveText()) - m2 <- ws.eitherClose(ws.receiveText()) - m3 <- ws.eitherClose(ws.receiveText()) - } yield List(m1, m2, m3) - }) - .get(baseUri.scheme("ws")) - .send(backend) - .map { r => - val results = r.body.map(_.map(_.left.map(_.statusCode))).value - results.take(2) shouldBe - List(Right("echo: test1"), Right("echo: test2")) - val closeCode = results.last.left.value - assert(closeCode == 1000 || closeCode == 1011) // some servers respond with Close(normal), some with Close(error) - } - }, - testServer( - endpoint.out( - webSocketBody[String, CodecFormat.TextPlain, String, CodecFormat.TextPlain](streams) - .autoPing(Some((200.millis, WebSocketFrame.ping))) - ), - "auto ping" - )((_: Unit) => pureResult(stringEcho.asRight[Unit])) { (backend, baseUri) => - basicRequest - .response(asWebSocket { (ws: WebSocket[IO]) => - for { - _ <- ws.sendText("test1") - _ <- IO.sleep(250.millis) - _ <- ws.sendText("test2") - m1 <- ws.receive() - m2 <- ws.receive() - _ <- ws.sendText("test3") - m3 <- ws.receive() - } yield List(m1, m2, m3) - }) - .get(baseUri.scheme("ws")) - .send(backend) - .map((r: Response[Either[String, List[WebSocketFrame]]]) => - assert(r.body.value.exists(_.isInstanceOf[WebSocketFrame.Ping]), s"Missing Ping frame in WS responses: $r") - ) - }, testServer( endpoint.out( webSocketBody[String, CodecFormat.TextPlain, String, CodecFormat.TextPlain](streams) @@ -199,7 +148,10 @@ abstract class ServerWebSocketTests[F[_], S <: Streams[S], OPTIONS, ROUTE]( .get(baseUri.scheme("ws")) .send(backend) .map((r: Response[Either[String, List[WebSocketFrame]]]) => - assert(r.body.value.exists(_.isInstanceOf[WebSocketFrame.Pong]), s"Missing Pong frame in WS responses: $r") + assert(r.body.value exists { + case WebSocketFrame.Pong(array) => array sameElements "test-ping-text".getBytes + case _ => false + }, s"Missing Pong(test-ping-text) in ${r.body}") ) }, testServer( @@ -249,7 +201,112 @@ abstract class ServerWebSocketTests[F[_], S <: Streams[S], OPTIONS, ROUTE]( .send(backend) .map(_.body shouldBe Left("Not a WS!")) } - ) + ) ++ autoPingTests ++ failingPipeTests ++ handlePongTests + + val autoPingTests = + if (autoPing) + List( + testServer( + endpoint.out( + webSocketBody[String, CodecFormat.TextPlain, String, 
CodecFormat.TextPlain](streams) + .autoPing(Some((200.millis, WebSocketFrame.ping))) + ), + "auto ping" + )((_: Unit) => pureResult(stringEcho.asRight[Unit])) { (backend, baseUri) => + basicRequest + .response(asWebSocket { (ws: WebSocket[IO]) => + for { + _ <- ws.sendText("test1") + _ <- IO.sleep(250.millis) + _ <- ws.sendText("test2") + m1 <- ws.receive() + m2 <- ws.receive() + _ <- ws.sendText("test3") + m3 <- ws.receive() + } yield List(m1, m2, m3) + }) + .get(baseUri.scheme("ws")) + .send(backend) + .map((r: Response[Either[String, List[WebSocketFrame]]]) => + assert(r.body.value.exists(_.isInstanceOf[WebSocketFrame.Ping]), s"Missing Ping frame in WS responses: $r") + ) + } + ) + else List.empty + + val failingPipeTests = + if (failingPipe) + List( + testServer( + endpoint.out(webSocketBody[String, CodecFormat.TextPlain, String, CodecFormat.TextPlain](streams)), + "failing pipe" + )((_: Unit) => + pureResult(functionToPipe[String, String] { + case "error-trigger" => throw new Exception("Boom!") + case msg => s"echo: $msg" + }.asRight[Unit]) + ) { (backend, baseUri) => + basicRequest + .response(asWebSocket { (ws: WebSocket[IO]) => + for { + _ <- ws.sendText("test1") + _ <- ws.sendText("test2") + _ <- ws.sendText("error-trigger") + m1 <- ws.eitherClose(ws.receiveText()) + m2 <- ws.eitherClose(ws.receiveText()) + m3 <- ws.eitherClose(ws.receiveText()) + } yield List(m1, m2, m3) + }) + .get(baseUri.scheme("ws")) + .send(backend) + .map { r => + val results = r.body.map(_.map(_.left.map(_.statusCode))).value + results.take(2) shouldBe + List(Right("echo: test1"), Right("echo: test2")) + val closeCode = results.last.left.value + assert(closeCode == 1000 || closeCode == 1011) // some servers respond with Close(normal), some with Close(error) + } + } + ) + else List.empty - // TODO: tests for ping/pong (control frames handling) + val handlePongTests = if (handlePong) List( + testServer( + { + implicit def textOrPongWebSocketFrame[A, CF <: CodecFormat](implicit + stringCodec: Codec[String, A, CF] + ): Codec[WebSocketFrame, A, CF] = + Codec // A custom codec to handle Pongs + .id[WebSocketFrame, CF](stringCodec.format, Schema.string) + .mapDecode { + case WebSocketFrame.Text(p, _, _) => stringCodec.decode(p) + case WebSocketFrame.Pong(payload) => + println(payload.length) + stringCodec.decode(new String(payload)) + case f => DecodeResult.Error(f.toString, new UnsupportedWebSocketFrameException(f)) + }(a => WebSocketFrame.text(stringCodec.encode(a))) + .schema(stringCodec.schema) + + endpoint.out( + webSocketBody[String, CodecFormat.TextPlain, String, CodecFormat.TextPlain](streams) + .autoPing(None) + .ignorePong(false) + ) + }, + "not ignore pong" + )((_: Unit) => pureResult(stringEcho.asRight[Unit])) { (backend, baseUri) => + basicRequest + .response(asWebSocket { (ws: WebSocket[IO]) => + for { + _ <- ws.sendText("test1") + _ <- ws.send(WebSocketFrame.Pong("test-pong-text".getBytes())) + m1 <- ws.receiveText() + _ <- ws.sendText("test2") + m2 <- ws.receiveText() + } yield List(m1, m2) + }) + .get(baseUri.scheme("ws")) + .send(backend) + .map(_.body shouldBe Right(List("echo: test1", "echo: test-pong-text"))) + }) else List.empty } diff --git a/server/vertx-server/cats/src/test/scala/sttp/tapir/server/vertx/cats/CatsVertxServerTest.scala b/server/vertx-server/cats/src/test/scala/sttp/tapir/server/vertx/cats/CatsVertxServerTest.scala index 9a96f67fd1..aac2e40122 100644 --- a/server/vertx-server/cats/src/test/scala/sttp/tapir/server/vertx/cats/CatsVertxServerTest.scala +++ 
b/server/vertx-server/cats/src/test/scala/sttp/tapir/server/vertx/cats/CatsVertxServerTest.scala @@ -37,7 +37,7 @@ class CatsVertxServerTest extends TestSuite { partOtherHeaderSupport = false ).tests() ++ new ServerStreamingTests(createServerTest).tests(Fs2Streams.apply[IO])(drainFs2) ++ - new ServerWebSocketTests(createServerTest, Fs2Streams.apply[IO]) { + new ServerWebSocketTests(createServerTest, Fs2Streams.apply[IO], autoPing = false, failingPipe = true, handlePong = true) { override def functionToPipe[A, B](f: A => B): streams.Pipe[A, B] = in => in.map(f) override def emptyPipe[A, B]: streams.Pipe[A, B] = _ => Stream.empty }.tests() diff --git a/server/vertx-server/src/test/scala/sttp/tapir/server/vertx/VertxServerTest.scala b/server/vertx-server/src/test/scala/sttp/tapir/server/vertx/VertxServerTest.scala index 8a7a8f57ce..394a8d22a2 100644 --- a/server/vertx-server/src/test/scala/sttp/tapir/server/vertx/VertxServerTest.scala +++ b/server/vertx-server/src/test/scala/sttp/tapir/server/vertx/VertxServerTest.scala @@ -27,7 +27,7 @@ class VertxServerTest extends TestSuite { def drainVertx[T](source: ReadStream[T]): Future[Unit] = { val p = Promise[Unit]() // Handler for stream data - do nothing with the data - val dataHandler: Handler[T] = (_: T) => () + val dataHandler: Handler[T] = (_: T) => () // End handler - complete the promise when the stream ends val endHandler: Handler[Void] = (_: Void) => p.success(()) @@ -53,7 +53,7 @@ class VertxServerTest extends TestSuite { partContentTypeHeaderSupport = true, partOtherHeaderSupport = false ).tests() ++ new ServerStreamingTests(createServerTest).tests(VertxStreams)(drainVertx[Buffer]) ++ - (new ServerWebSocketTests(createServerTest, VertxStreams) { + (new ServerWebSocketTests(createServerTest, VertxStreams, autoPing = false, failingPipe = false, handlePong = true) { override def functionToPipe[A, B](f: A => B): VertxStreams.Pipe[A, B] = in => new ReadStreamMapping(in, f) override def emptyPipe[A, B]: VertxStreams.Pipe[A, B] = _ => new EmptyReadStream() }).tests() diff --git a/server/vertx-server/zio/src/test/scala/sttp/tapir/server/vertx/zio/ZioVertxServerTest.scala b/server/vertx-server/zio/src/test/scala/sttp/tapir/server/vertx/zio/ZioVertxServerTest.scala index 0a8c27f4e4..f2ca8846c5 100644 --- a/server/vertx-server/zio/src/test/scala/sttp/tapir/server/vertx/zio/ZioVertxServerTest.scala +++ b/server/vertx-server/zio/src/test/scala/sttp/tapir/server/vertx/zio/ZioVertxServerTest.scala @@ -43,7 +43,7 @@ class ZioVertxServerTest extends TestSuite with OptionValues { partOtherHeaderSupport = false ).tests() ++ additionalTests() ++ new ServerStreamingTests(createServerTest).tests(ZioStreams)(drainZStream) ++ - new ServerWebSocketTests(createServerTest, ZioStreams) { + new ServerWebSocketTests(createServerTest, ZioStreams, autoPing = true, failingPipe = true, handlePong = true) { override def functionToPipe[A, B](f: A => B): streams.Pipe[A, B] = in => in.map(f) override def emptyPipe[A, B]: streams.Pipe[A, B] = _ => ZStream.empty }.tests() diff --git a/server/zio-http-server/src/test/scala/sttp/tapir/server/ziohttp/ZioHttpServerTest.scala b/server/zio-http-server/src/test/scala/sttp/tapir/server/ziohttp/ZioHttpServerTest.scala index 828290d902..7e6043166a 100644 --- a/server/zio-http-server/src/test/scala/sttp/tapir/server/ziohttp/ZioHttpServerTest.scala +++ b/server/zio-http-server/src/test/scala/sttp/tapir/server/ziohttp/ZioHttpServerTest.scala @@ -266,7 +266,7 @@ class ZioHttpServerTest extends TestSuite { ).tests() ++ new 
ServerStreamingTests(createServerTest).tests(ZioStreams)(drainZStream) ++ new ZioHttpCompositionTest(createServerTest).tests() ++ - new ServerWebSocketTests(createServerTest, ZioStreams) { + new ServerWebSocketTests(createServerTest, ZioStreams, autoPing = true, failingPipe = false, handlePong = false) { override def functionToPipe[A, B](f: A => B): ZioStreams.Pipe[A, B] = in => in.map(f) override def emptyPipe[A, B]: ZioStreams.Pipe[A, B] = _ => ZStream.empty }.tests() ++ From b4e719261fc56cfdf6371ed6d3ea41a81beae6c4 Mon Sep 17 00:00:00 2001 From: kciesielski Date: Wed, 27 Mar 2024 01:38:47 +0100 Subject: [PATCH 14/31] Remove println --- .../scala/sttp/tapir/server/tests/ServerWebSocketTests.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerWebSocketTests.scala b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerWebSocketTests.scala index 7c5fac8d63..5596457a85 100644 --- a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerWebSocketTests.scala +++ b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerWebSocketTests.scala @@ -281,7 +281,6 @@ abstract class ServerWebSocketTests[F[_], S <: Streams[S], OPTIONS, ROUTE]( .mapDecode { case WebSocketFrame.Text(p, _, _) => stringCodec.decode(p) case WebSocketFrame.Pong(payload) => - println(payload.length) stringCodec.decode(new String(payload)) case f => DecodeResult.Error(f.toString, new UnsupportedWebSocketFrameException(f)) }(a => WebSocketFrame.text(stringCodec.encode(a))) From 83fa2f807cd1399661dfce14525481fd335b6bf7 Mon Sep 17 00:00:00 2001 From: kciesielski Date: Wed, 27 Mar 2024 02:00:00 +0100 Subject: [PATCH 15/31] Fix memory leak --- .../netty/internal/WebSocketControlFrameHandler.scala | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/WebSocketControlFrameHandler.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/WebSocketControlFrameHandler.scala index cd7e6ca919..f8c16bfaf0 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/WebSocketControlFrameHandler.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/WebSocketControlFrameHandler.scala @@ -16,10 +16,14 @@ class NettyControlFrameHandler(ignorePong: Boolean, autoPongOnPing: Boolean, dec case ping: PingWebSocketFrame => if (autoPongOnPing) { val _ = ctx.writeAndFlush(new PongWebSocketFrame(ping.content().retain())) + } else { + val _ = ping.content().release() } - case pong: PongWebSocketFrame => + case pong: PongWebSocketFrame => if (!ignorePong) { val _ = ctx.fireChannelRead(pong) + } else { + val _ = pong.content().release() } case close: CloseWebSocketFrame => if (decodeCloseRequests) { From 4705519ea4e9d7b01ddc375627d5ec7a2886e80a Mon Sep 17 00:00:00 2001 From: kciesielski Date: Wed, 27 Mar 2024 02:19:01 +0100 Subject: [PATCH 16/31] Cleanup and minor tweaks --- .../cats/internal/Fs2StreamCompatible.scala | 9 +- .../server/netty/NettyResponseContent.scala | 6 +- .../netty/internal/NettyServerHandler.scala | 24 ++--- .../internal/NettyToStreamsResponseBody.scala | 7 +- .../netty/internal/StreamCompatible.scala | 2 +- .../internal/WebSocketAutoPingHandler.scala | 1 - .../server/tests/ServerWebSocketTests.scala | 92 ++++++++++--------- 7 files changed, 75 insertions(+), 66 deletions(-) diff --git a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/Fs2StreamCompatible.scala 
b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/Fs2StreamCompatible.scala index e74794a1e0..1b5ea432a7 100644 --- a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/Fs2StreamCompatible.scala +++ b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/Fs2StreamCompatible.scala @@ -71,20 +71,19 @@ object Fs2StreamCompatible { o: WebSocketBodyOutput[Pipe[F, REQ, RESP], REQ, RESP, ?, Fs2Streams[F]], ctx: ChannelHandlerContext ): Processor[WebSocketFrame, WebSocketFrame] = { - val onCancelPromise = ctx.newPromise() - onCancelPromise.addListener((f: ChannelFuture) => { + val processorPromise = ctx.newPromise() + processorPromise.addListener((f: ChannelFuture) => { // A special callback that has to be used when a SteramSubscription cancels or fails. // This can happen in case of errors in the pipeline which are not signalled correctly, // like throwing exceptions directly. // Without explicit Close frame a client may hang on waiting and not knowing about closed channel. if (f.isCancelled) { val _ = ctx.writeAndFlush(new CloseWebSocketFrame(WebSocketCloseStatus.NORMAL_CLOSURE, "Canceled")) - } - else if (!f.isSuccess) { + } else if (!f.isSuccess) { val _ = ctx.writeAndFlush(new CloseWebSocketFrame(WebSocketCloseStatus.INTERNAL_SERVER_ERROR, "Error")) } }) - new WebSocketPipeProcessor[F, REQ, RESP](pipe, dispatcher, o, onCancelPromise) + new WebSocketPipeProcessor[F, REQ, RESP](pipe, dispatcher, o, processorPromise) } private def inputStreamToFs2(inputStream: () => InputStream, chunkSize: Int) = diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyResponseContent.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyResponseContent.scala index d3aea1dda6..77882e8dfd 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyResponseContent.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyResponseContent.scala @@ -3,11 +3,11 @@ package sttp.tapir.server.netty import io.netty.buffer.ByteBuf import io.netty.channel.ChannelPromise import io.netty.handler.codec.http.HttpContent -import io.netty.handler.stream.{ChunkedFile, ChunkedStream} -import org.reactivestreams.Publisher -import org.reactivestreams.Processor import io.netty.handler.codec.http.websocketx.WebSocketFrame +import io.netty.handler.stream.{ChunkedFile, ChunkedStream} +import org.reactivestreams.{Processor, Publisher} import sttp.ws.{WebSocketFrame => SttpWebSocketFrame} + import scala.concurrent.duration.FiniteDuration sealed trait NettyResponseContent { diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala index 3da92d8f56..1ec09f7d27 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala @@ -4,7 +4,6 @@ import io.netty.buffer.{ByteBuf, Unpooled} import io.netty.channel._ import io.netty.channel.group.ChannelGroup import io.netty.handler.codec.http._ -import io.netty.handler.codec.http.websocketx.WebSocketServerHandshakerFactory import io.netty.handler.stream.{ChunkedFile, ChunkedStream} import org.playframework.netty.http.{DefaultStreamedHttpResponse, StreamedHttpRequest} import org.reactivestreams.Publisher @@ -16,7 +15,8 @@ import 
sttp.tapir.server.netty.NettyResponseContent.{ ByteBufNettyResponseContent, ChunkedFileNettyResponseContent, ChunkedStreamNettyResponseContent, - ReactivePublisherNettyResponseContent + ReactivePublisherNettyResponseContent, + ReactiveWebSocketProcessorNettyResponseContent } import sttp.tapir.server.netty.{NettyResponse, NettyServerRequest, Route} @@ -26,8 +26,6 @@ import scala.collection.mutable.{Queue => MutableQueue} import scala.concurrent.{ExecutionContext, Future} import scala.util.control.NonFatal import scala.util.{Failure, Success} -import sttp.tapir.EndpointInput.AuthType.Http -import sttp.tapir.server.netty.NettyResponseContent.ReactiveWebSocketProcessorNettyResponseContent /** @param unsafeRunAsync * Function which dispatches given effect to run asynchronously, returning its result as a Future, and function of type `() => @@ -216,10 +214,12 @@ class NettyServerHandler[F[_]]( }, wsHandler = (channelPromise) => { - logger.error("Unexpected WebSocket processor response received in NettyServerHandler, it should be handled only in the ReactiveWebSocketHandler") - val res = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.INTERNAL_SERVER_ERROR) - res.headers.set(HttpHeaderNames.CONNECTION, HttpHeaderValues.CLOSE) - val _ = ctx.writeAndFlush(res) + logger.error( + "Unexpected WebSocket processor response received in NettyServerHandler, it should be handled only in the ReactiveWebSocketHandler" + ) + val res = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.INTERNAL_SERVER_ERROR) + res.headers.set(HttpHeaderNames.CONNECTION, HttpHeaderValues.CLOSE) + val _ = ctx.writeAndFlush(res) }, noBodyHandler = () => { val res = new DefaultFullHttpResponse( @@ -251,10 +251,10 @@ class NettyServerHandler[F[_]]( val values = function(ctx) values match { - case r: ByteBufNettyResponseContent => byteBufHandler(r.channelPromise, r.byteBuf) - case r: ChunkedStreamNettyResponseContent => chunkedStreamHandler(r.channelPromise, r.chunkedStream) - case r: ChunkedFileNettyResponseContent => chunkedFileHandler(r.channelPromise, r.chunkedFile) - case r: ReactivePublisherNettyResponseContent => reactiveStreamHandler(r.channelPromise, r.publisher) + case r: ByteBufNettyResponseContent => byteBufHandler(r.channelPromise, r.byteBuf) + case r: ChunkedStreamNettyResponseContent => chunkedStreamHandler(r.channelPromise, r.chunkedStream) + case r: ChunkedFileNettyResponseContent => chunkedFileHandler(r.channelPromise, r.chunkedFile) + case r: ReactivePublisherNettyResponseContent => reactiveStreamHandler(r.channelPromise, r.publisher) case r: ReactiveWebSocketProcessorNettyResponseContent => wsHandler(r.channelPromise) } } diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyToStreamsResponseBody.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyToStreamsResponseBody.scala index 998efd390d..507eee8a70 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyToStreamsResponseBody.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyToStreamsResponseBody.scala @@ -6,13 +6,16 @@ import sttp.capabilities.Streams import sttp.model.HasHeaders import sttp.tapir.server.interpreter.ToResponseBody import sttp.tapir.server.netty.NettyResponse +import sttp.tapir.server.netty.NettyResponseContent.{ + ByteBufNettyResponseContent, + ReactivePublisherNettyResponseContent, + ReactiveWebSocketProcessorNettyResponseContent +} import 
sttp.tapir.server.netty.internal.NettyToResponseBody._ -import sttp.tapir.server.netty.NettyResponseContent.{ByteBufNettyResponseContent, ReactivePublisherNettyResponseContent} import sttp.tapir.{CodecFormat, RawBodyType, WebSocketBodyOutput} import java.nio.ByteBuffer import java.nio.charset.Charset -import sttp.tapir.server.netty.NettyResponseContent.ReactiveWebSocketProcessorNettyResponseContent /** Common logic for producing response body in all Netty backends that support streaming. These backends use streaming libraries like fs2 * or zio-streams to obtain reactive Publishers representing responses like InputStreamBody, InputStreamRangeBody or FileBody. Other kinds diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/StreamCompatible.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/StreamCompatible.scala index 01f49586a6..bbcdb4ea3e 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/StreamCompatible.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/StreamCompatible.scala @@ -1,5 +1,6 @@ package sttp.tapir.server.netty.internal +import io.netty.channel.ChannelHandlerContext import io.netty.handler.codec.http.HttpContent import io.netty.handler.codec.http.websocketx.WebSocketFrame import org.reactivestreams.{Processor, Publisher} @@ -7,7 +8,6 @@ import sttp.capabilities.Streams import sttp.tapir.{FileRange, WebSocketBodyOutput} import java.io.InputStream -import io.netty.channel.ChannelHandlerContext /** Operations on streams that have to be implemented for each streaming integration (fs2, zio-streams, etc) used by Netty backends. This * includes conversions like building a stream from a `File`, an `InputStream`, or a reactive `Publisher`. We also need implementation of a diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/WebSocketAutoPingHandler.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/WebSocketAutoPingHandler.scala index 401d668ca5..c62eb6ba9a 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/WebSocketAutoPingHandler.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/WebSocketAutoPingHandler.scala @@ -31,7 +31,6 @@ class WebSocketAutoPingHandler(pingInterval: FiniteDuration, frame: sttp.ws.WebS val _ = ctx.writeAndFlush(nettyFrame.retain()) } } - // FIXME should not start before the handshake response is sent! 
pingTask = ctx.channel().eventLoop().scheduleAtFixedRate(sendPing, pingInterval.toMillis, pingInterval.toMillis, TimeUnit.MILLISECONDS) } diff --git a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerWebSocketTests.scala b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerWebSocketTests.scala index 5596457a85..09faceac0a 100644 --- a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerWebSocketTests.scala +++ b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerWebSocketTests.scala @@ -148,10 +148,13 @@ abstract class ServerWebSocketTests[F[_], S <: Streams[S], OPTIONS, ROUTE]( .get(baseUri.scheme("ws")) .send(backend) .map((r: Response[Either[String, List[WebSocketFrame]]]) => - assert(r.body.value exists { - case WebSocketFrame.Pong(array) => array sameElements "test-ping-text".getBytes - case _ => false - }, s"Missing Pong(test-ping-text) in ${r.body}") + assert( + r.body.value exists { + case WebSocketFrame.Pong(array) => array sameElements "test-ping-text".getBytes + case _ => false + }, + s"Missing Pong(test-ping-text) in ${r.body}" + ) ) }, testServer( @@ -209,7 +212,7 @@ abstract class ServerWebSocketTests[F[_], S <: Streams[S], OPTIONS, ROUTE]( testServer( endpoint.out( webSocketBody[String, CodecFormat.TextPlain, String, CodecFormat.TextPlain](streams) - .autoPing(Some((200.millis, WebSocketFrame.ping))) + .autoPing(Some((50.millis, WebSocketFrame.ping))) ), "auto ping" )((_: Unit) => pureResult(stringEcho.asRight[Unit])) { (backend, baseUri) => @@ -217,7 +220,7 @@ abstract class ServerWebSocketTests[F[_], S <: Streams[S], OPTIONS, ROUTE]( .response(asWebSocket { (ws: WebSocket[IO]) => for { _ <- ws.sendText("test1") - _ <- IO.sleep(250.millis) + _ <- IO.sleep(150.millis) _ <- ws.sendText("test2") m1 <- ws.receive() m2 <- ws.receive() @@ -234,6 +237,7 @@ abstract class ServerWebSocketTests[F[_], S <: Streams[S], OPTIONS, ROUTE]( ) else List.empty + // Optional, because some backends don't handle exceptions in the pipe gracefully, they just swallow it silently and hang forever val failingPipeTests = if (failingPipe) List( @@ -270,42 +274,46 @@ abstract class ServerWebSocketTests[F[_], S <: Streams[S], OPTIONS, ROUTE]( ) else List.empty - val handlePongTests = if (handlePong) List( - testServer( - { - implicit def textOrPongWebSocketFrame[A, CF <: CodecFormat](implicit - stringCodec: Codec[String, A, CF] - ): Codec[WebSocketFrame, A, CF] = - Codec // A custom codec to handle Pongs - .id[WebSocketFrame, CF](stringCodec.format, Schema.string) - .mapDecode { - case WebSocketFrame.Text(p, _, _) => stringCodec.decode(p) - case WebSocketFrame.Pong(payload) => - stringCodec.decode(new String(payload)) - case f => DecodeResult.Error(f.toString, new UnsupportedWebSocketFrameException(f)) - }(a => WebSocketFrame.text(stringCodec.encode(a))) - .schema(stringCodec.schema) + val handlePongTests = + if (handlePong) + List( + testServer( + { + implicit def textOrPongWebSocketFrame[A, CF <: CodecFormat](implicit + stringCodec: Codec[String, A, CF] + ): Codec[WebSocketFrame, A, CF] = + Codec // A custom codec to handle Pongs + .id[WebSocketFrame, CF](stringCodec.format, Schema.string) + .mapDecode { + case WebSocketFrame.Text(p, _, _) => stringCodec.decode(p) + case WebSocketFrame.Pong(payload) => + stringCodec.decode(new String(payload)) + case f => DecodeResult.Error(f.toString, new UnsupportedWebSocketFrameException(f)) + }(a => WebSocketFrame.text(stringCodec.encode(a))) + .schema(stringCodec.schema) - endpoint.out( - webSocketBody[String, CodecFormat.TextPlain, 
String, CodecFormat.TextPlain](streams) - .autoPing(None) - .ignorePong(false) + endpoint.out( + webSocketBody[String, CodecFormat.TextPlain, String, CodecFormat.TextPlain](streams) + .autoPing(None) + .ignorePong(false) + ) + }, + "not ignore pong" + )((_: Unit) => pureResult(stringEcho.asRight[Unit])) { (backend, baseUri) => + basicRequest + .response(asWebSocket { (ws: WebSocket[IO]) => + for { + _ <- ws.sendText("test1") + _ <- ws.send(WebSocketFrame.Pong("test-pong-text".getBytes())) + m1 <- ws.receiveText() + _ <- ws.sendText("test2") + m2 <- ws.receiveText() + } yield List(m1, m2) + }) + .get(baseUri.scheme("ws")) + .send(backend) + .map(_.body shouldBe Right(List("echo: test1", "echo: test-pong-text"))) + } ) - }, - "not ignore pong" - )((_: Unit) => pureResult(stringEcho.asRight[Unit])) { (backend, baseUri) => - basicRequest - .response(asWebSocket { (ws: WebSocket[IO]) => - for { - _ <- ws.sendText("test1") - _ <- ws.send(WebSocketFrame.Pong("test-pong-text".getBytes())) - m1 <- ws.receiveText() - _ <- ws.sendText("test2") - m2 <- ws.receiveText() - } yield List(m1, m2) - }) - .get(baseUri.scheme("ws")) - .send(backend) - .map(_.body shouldBe Right(List("echo: test1", "echo: test-pong-text"))) - }) else List.empty + else List.empty } From 929d708f6d65d84687ac9df510a7afa205955397 Mon Sep 17 00:00:00 2001 From: kciesielski Date: Wed, 27 Mar 2024 08:46:33 +0100 Subject: [PATCH 17/31] Documentation --- build.sbt | 1 + doc/server/netty.md | 65 ++++++++++++++++ .../examples/WebSocketsNettyCatsServer.scala | 75 ------------------- 3 files changed, 66 insertions(+), 75 deletions(-) delete mode 100644 examples/src/main/scala/sttp/tapir/examples/WebSocketsNettyCatsServer.scala diff --git a/build.sbt b/build.sbt index 3cccfc5086..feed8f00ed 100644 --- a/build.sbt +++ b/build.sbt @@ -2196,6 +2196,7 @@ lazy val documentation: ProjectMatrix = (projectMatrix in file("generated-doc")) sprayJson, http4sClient, http4sServerZio, + nettyServerCats, sttpClient, playClient, sttpStubServer, diff --git a/doc/server/netty.md b/doc/server/netty.md index 3fd9cec741..cbeee80c9c 100644 --- a/doc/server/netty.md +++ b/doc/server/netty.md @@ -83,6 +83,71 @@ NettyFutureServer(NettyFutureServerOptions.customiseInterceptors.serverLog(None) NettyFutureServer(NettyConfig.default.socketBacklog(256)) ``` +## Web sockets + +The netty-cats interpreter supports web sockets, with pipes of type `fs2.Pipe[F, REQ, RESP]`. See [web sockets](../endpoint/websockets.md) +for more details. 
+ +To create a web socket endpoint, use Tapir's `out(webSocketBody)` output type: + +```scala mdoc:compile-only +import cats.effect.kernel.Resource +import cats.effect.{IO, ResourceApp} +import cats.syntax.all._ +import fs2.Pipe +import sttp.capabilities.fs2.Fs2Streams +import sttp.tapir._ +import sttp.tapir.server.netty.cats.NettyCatsServer +import sttp.ws.WebSocketFrame + +import scala.concurrent.duration._ + +object WebSocketsNettyCatsServer extends ResourceApp.Forever { + + // Web socket endpoint + val wsEndpoint = + endpoint.get + .in("ws") + .out( + webSocketBody[String, CodecFormat.TextPlain, String, CodecFormat.TextPlain](Fs2Streams[IO]) + .concatenateFragmentedFrames(false) // All these options are supported by tapir-netty + .ignorePong(true) + .autoPongOnPing(true) + .decodeCloseRequests(false) + .decodeCloseResponses(false) + .autoPing(Some((10.seconds, WebSocketFrame.Ping("ping-content".getBytes)))) + ) + + // Your processor transforming a stream of requests into a stream of responses + val pipe: Pipe[IO, String, String] = requestStream => requestStream.evalMap(str => IO.pure(str.toUpperCase)) + // Alternatively, requests can be ignored and the backend can be turned into a stream emitting frames to the client: + // val pipe: Pipe[IO, String, String] = requestStream => someDataEmittingStream.concurrently(requestStream.as(())) + + val wsServerEndpoint = wsEndpoint.serverLogicSuccess(_ => IO.pure(pipe)) + + // A regular /GET endpoint + val helloWorldEndpoint: PublicEndpoint[String, Unit, String, Any] = + endpoint.get.in("hello").in(query[String]("name")).out(stringBody) + + val helloWorldServerEndpoint = helloWorldEndpoint + .serverLogicSuccess(name => IO.pure(s"Hello, $name!")) + + override def run(args: List[String]) = NettyCatsServer + .io() + .flatMap { server => + Resource + .make( + server + .port(8080) + .host("localhost") + .addEndpoints(List(wsServerEndpoint, helloWorldServerEndpoint)) + .start() + )(_.stop()) + .as(()) + } +} +``` + ## Graceful shutdown A Netty server can be gracefully closed using the function `NettyFutureServerBinding.stop()` (and analogous functions available in Cats and ZIO bindings). This function ensures that the server will wait at most 10 seconds for in-flight requests to complete, while rejecting all new requests with 503 during this period. Afterwards, it closes all server resources. 
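
For completeness, a minimal client-side sketch for the documented `/ws` endpoint, adapted from the `WebSocketsNettyCatsServer` example removed below. This is not part of the patch: it assumes the server shown above is running on `localhost:8080` and that sttp's Pekko HTTP backend is on the classpath, as in that removed example.

```scala
// Client sketch (assumption: server from the docs above is listening on localhost:8080).
import sttp.client3._
import sttp.client3.pekkohttp.PekkoHttpBackend
import sttp.ws.WebSocket

import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.{Await, Future}
import scala.concurrent.duration._

object WsClientSketch extends App {
  val backend = PekkoHttpBackend()

  // Send two text frames and print the upper-cased echoes produced by the server's pipe
  def useWebSocket(ws: WebSocket[Future]): Future[Unit] =
    for {
      _  <- ws.sendText("hello")
      r1 <- ws.receiveText()
      _   = println(s"Client received: $r1")
      _  <- ws.sendText("world")
      r2 <- ws.receiveText()
      _   = println(s"Client received: $r2")
    } yield ()

  val done = basicRequest
    .response(asWebSocket(useWebSocket))
    .get(uri"ws://localhost:8080/ws")
    .send(backend)
    .andThen { case _ => backend.close() } // close the client backend regardless of outcome

  Await.result(done, 30.seconds)
}
```

The request-response flow mirrors the removed `WebSocketsNettyCatsServer` example, just without the Cats Effect wrapper around the client calls; the frame-level options configured on `webSocketBody` (auto-ping, pong handling, close-frame decoding) are handled by the server pipeline, so the client only deals with text messages.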
diff --git a/examples/src/main/scala/sttp/tapir/examples/WebSocketsNettyCatsServer.scala b/examples/src/main/scala/sttp/tapir/examples/WebSocketsNettyCatsServer.scala deleted file mode 100644 index 50729f0c2e..0000000000 --- a/examples/src/main/scala/sttp/tapir/examples/WebSocketsNettyCatsServer.scala +++ /dev/null @@ -1,75 +0,0 @@ -package sttp.tapir.examples - -import cats.effect.{IO, IOApp} -import sttp.client3._ -import sttp.model.StatusCode -import sttp.tapir.server.netty.cats.NettyCatsServer -import sttp.tapir.* -import scala.concurrent.duration._ -import sttp.capabilities.fs2.Fs2Streams -import sttp.ws.WebSocket -import sttp.client3.pekkohttp.PekkoHttpBackend -import scala.concurrent.Future - -object WebSocketsNettyCatsServer extends IOApp.Simple { - // One endpoint on GET /hello with query parameter `name` - val helloWorldEndpoint: PublicEndpoint[String, Unit, String, Any] = - endpoint.get.in("hello").in(query[String]("name")).out(stringBody) - - val wsEndpoint = - endpoint.get.in("ws").out(webSocketBody[String, CodecFormat.TextPlain, String, CodecFormat.TextPlain](Fs2Streams[IO])) - - val wsServerEndpoint = wsEndpoint.serverLogicSuccess(_ => - IO.pure(in => in.evalMap(str => IO.println(s"responding with ${str.toUpperCase}") >> IO.pure(str.toUpperCase()))) - ) - // Just returning passed name with `Hello, ` prepended - val helloWorldServerEndpoint = helloWorldEndpoint - .serverLogic(name => IO.pure[Either[Unit, String]](Right(s"Hello, $name!"))) - - private val declaredPort = 9090 - private val declaredHost = "localhost" - - // Creating handler for netty bootstrap - override def run = NettyCatsServer - .io() - .use { server => - for { - binding <- server - .port(declaredPort) - .host(declaredHost) - .addEndpoints(List(wsServerEndpoint, helloWorldServerEndpoint)) - .start() - result <- IO - .fromFuture(IO.delay { - val port = binding.port - val host = binding.hostName - println(s"Server started at port = ${binding.port}") - import scala.concurrent.ExecutionContext.Implicits.global - def useWebSocket(ws: WebSocket[Future]): Future[Unit] = { - def send(i: Int) = ws.sendText(s"Hello $i!") - def receive() = ws.receiveText().map(t => println(s"Client RECEIVED: $t")) - for { - _ <- send(1) - _ <- receive() - _ <- send(2) - _ <- send(3) - _ <- receive() - } yield () - } - val backend = PekkoHttpBackend() - - val url = uri"ws://$host:$port/ws" - val allGood = uri"http://$host:$port/hello?name=Netty" - basicRequest.response(asStringAlways).get(allGood).send(backend).map(r => println(r.body)) - .flatMap { _ => - basicRequest - .response(asWebSocket(useWebSocket)) - .get(url) - .send(backend) - } - .andThen { case _ => backend.close() } - }) - .guarantee(binding.stop()) - } yield result - } -} From 84977908e7b3825492ae0e198714b29d81d76fd0 Mon Sep 17 00:00:00 2001 From: kciesielski Date: Wed, 27 Mar 2024 10:26:26 +0100 Subject: [PATCH 18/31] Organize handlers --- .../scala/sttp/tapir/server/netty/NettyConfig.scala | 10 +++++++--- .../tapir/server/netty/internal/NettyBootstrap.scala | 4 +++- .../netty/internal/ReactiveWebSocketHandler.scala | 12 +++++++----- .../sttp/tapir/server/netty/internal/package.scala | 3 +++ 4 files changed, 20 insertions(+), 9 deletions(-) diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyConfig.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyConfig.scala index 1d6d66144d..115c387be1 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyConfig.scala +++ 
b/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyConfig.scala @@ -10,6 +10,7 @@ import io.netty.handler.logging.LoggingHandler import io.netty.handler.ssl.SslContext import org.playframework.netty.http.HttpStreamsServerHandler import sttp.tapir.server.netty.NettyConfig.EventLoopConfig +import sttp.tapir.server.netty.internal._ import scala.concurrent.duration._ @@ -105,6 +106,9 @@ case class NettyConfig( } object NettyConfig { + private val WebSocketHandlerName = "wsHandler" + val StreamsHandlerName = "streamsHandler" + def default: NettyConfig = NettyConfig( host = "localhost", port = 8080, @@ -126,9 +130,9 @@ object NettyConfig { def defaultInitPipeline(cfg: NettyConfig)(pipeline: ChannelPipeline, handler: ChannelHandler, wsHandler: ChannelHandler): Unit = { cfg.sslContext.foreach(s => pipeline.addLast(s.newHandler(pipeline.channel().alloc()))) - pipeline.addLast("serverCodecHandler", new HttpServerCodec()) - pipeline.addLast("streamsHandler", new HttpStreamsServerHandler()) - pipeline.addLast("wsHandler", wsHandler) + pipeline.addLast(ServerCodecHandlerName, new HttpServerCodec()) + pipeline.addLast(StreamsHandlerName, new HttpStreamsServerHandler()) + pipeline.addLast(WebSocketHandlerName, wsHandler) pipeline.addLast(handler) if (cfg.addLoggingHandler) pipeline.addLast(new LoggingHandler()) () diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyBootstrap.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyBootstrap.scala index 68835c8f07..5c864b3fa9 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyBootstrap.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyBootstrap.scala @@ -9,6 +9,8 @@ import java.net.{InetSocketAddress, SocketAddress} object NettyBootstrap { + private val ReadTimeoutHandlerName = "readTimeoutHandler" + def apply[F[_]]( nettyConfig: NettyConfig, handler: => NettyServerHandler[F], @@ -29,7 +31,7 @@ object NettyBootstrap { nettyConfig.requestTimeout match { case Some(requestTimeout) => nettyConfigBuilder( - ch.pipeline().addLast("readTimeoutHandler", new ReadTimeoutHandler(requestTimeout.toSeconds.toInt)), + ch.pipeline().addLast(ReadTimeoutHandlerName, new ReadTimeoutHandler(requestTimeout.toSeconds.toInt)), handler, wsHandler ) diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/ReactiveWebSocketHandler.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/ReactiveWebSocketHandler.scala index e00d3467cd..81c81459fa 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/ReactiveWebSocketHandler.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/ReactiveWebSocketHandler.scala @@ -2,8 +2,9 @@ package sttp.tapir.server.netty.internal import io.netty.channel.group.ChannelGroup import io.netty.channel.{ChannelHandlerContext, ChannelInboundHandlerAdapter} -import io.netty.handler.codec.http.websocketx.WebSocketServerHandshakerFactory import io.netty.handler.codec.http._ +import io.netty.handler.codec.http.websocketx.WebSocketServerHandshakerFactory +import io.netty.handler.timeout.ReadTimeoutHandler import io.netty.util.ReferenceCountUtil import org.playframework.netty.http.DefaultWebSocketHttpResponse import org.slf4j.LoggerFactory @@ -37,6 +38,7 @@ class ReactiveWebSocketHandler[F[_]]( private[this] var eventLoopContext: ExecutionContext = _ private val logger = LoggerFactory.getLogger(getClass.getName) + private val 
WebSocketAutoPingHandlerName = "wsAutoPingHandler" def isWsHandshake(req: HttpRequest): Boolean = "Websocket".equalsIgnoreCase(req.headers().get(HttpHeaderNames.UPGRADE)) && @@ -70,7 +72,7 @@ class ReactiveWebSocketHandler[F[_]]( msg match { case req: FullHttpRequest if isWsHandshake(req) => ctx.pipeline().remove(this) - ctx.pipeline().remove("readTimeoutHandler") + ctx.pipeline().remove(classOf[ReadTimeoutHandler]) ReferenceCountUtil.release(msg) val (runningFuture, _) = unsafeRunAsync { () => route(NettyServerRequest(req.retain())) @@ -91,8 +93,8 @@ class ReactiveWebSocketHandler[F[_]]( ctx .pipeline() .addAfter( - "serverCodecHandler", - "wsControlFrameHandler", + ServerCodecHandlerName, + WebSocketControlFrameHandlerName, new NettyControlFrameHandler( ignorePong = r.ignorePong, autoPongOnPing = r.autoPongOnPing, @@ -102,7 +104,7 @@ class ReactiveWebSocketHandler[F[_]]( r.autoPing.foreach { case (interval, pingMsg) => ctx .pipeline() - .addAfter("wsControlFrameHandler", "wsAutoPingHandler", new WebSocketAutoPingHandler(interval, pingMsg)) + .addAfter(WebSocketControlFrameHandlerName, WebSocketAutoPingHandlerName, new WebSocketAutoPingHandler(interval, pingMsg)) } // Manually completing the promise, for some reason it won't be completed in writeAndFlush. We need its completion for NettyBodyListener to call back properly r.channelPromise.setSuccess() diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/package.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/package.scala index 12ac1afd26..fa0f0f99ee 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/package.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/package.scala @@ -10,4 +10,7 @@ package object internal { def toHeaderSeq: List[Header] = underlying.asScala.map(e => Header(e.getKey, e.getValue)).toList } + + val ServerCodecHandlerName = "serverCodecHandler" + val WebSocketControlFrameHandlerName = "wsControlFrameHandler" } From 2838ece36e63a07e30a485c6de84f3945f1ba76f Mon Sep 17 00:00:00 2001 From: kciesielski Date: Wed, 27 Mar 2024 10:57:00 +0100 Subject: [PATCH 19/31] Handle hanshake for regular endpoints with 400 --- .../netty/internal/ReactiveWebSocketHandler.scala | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/ReactiveWebSocketHandler.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/ReactiveWebSocketHandler.scala index 81c81459fa..b004659c1d 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/ReactiveWebSocketHandler.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/ReactiveWebSocketHandler.scala @@ -1,5 +1,6 @@ package sttp.tapir.server.netty.internal +import io.netty.buffer.Unpooled import io.netty.channel.group.ChannelGroup import io.netty.channel.{ChannelHandlerContext, ChannelInboundHandlerAdapter} import io.netty.handler.codec.http._ @@ -104,7 +105,11 @@ class ReactiveWebSocketHandler[F[_]]( r.autoPing.foreach { case (interval, pingMsg) => ctx .pipeline() - .addAfter(WebSocketControlFrameHandlerName, WebSocketAutoPingHandlerName, new WebSocketAutoPingHandler(interval, pingMsg)) + .addAfter( + WebSocketControlFrameHandlerName, + WebSocketAutoPingHandlerName, + new WebSocketAutoPingHandler(interval, pingMsg) + ) } // Manually completing the promise, for some reason it won't be completed in writeAndFlush. 
We need its completion for NettyBodyListener to call back properly r.channelPromise.setSuccess() @@ -121,7 +126,11 @@ class ReactiveWebSocketHandler[F[_]]( } case otherContent => logger.error(s"Unexpected response content: $otherContent") - val res = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.INTERNAL_SERVER_ERROR) + val res = new DefaultFullHttpResponse( + HttpVersion.HTTP_1_1, + HttpResponseStatus.BAD_REQUEST, + Unpooled.wrappedBuffer("WebSocket handshake received on a regular HTTP endpoint".getBytes) + ) res.headers.set(HttpHeaderNames.CONNECTION, HttpHeaderValues.CLOSE) otherContent.channelPromise.setFailure(new IllegalStateException("Unexpected response content")) val _ = ctx.writeAndFlush(res) From e9a213fd5c2ff766eea9675ba3ee67a4c94f3f38 Mon Sep 17 00:00:00 2001 From: kciesielski Date: Wed, 27 Mar 2024 11:45:42 +0100 Subject: [PATCH 20/31] Extract methods and reply 400 on regular endpoints --- .../internal/ReactiveWebSocketHandler.scala | 109 ++++++++++-------- 1 file changed, 58 insertions(+), 51 deletions(-) diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/ReactiveWebSocketHandler.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/ReactiveWebSocketHandler.scala index b004659c1d..68faa242b0 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/ReactiveWebSocketHandler.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/ReactiveWebSocketHandler.scala @@ -13,7 +13,7 @@ import sttp.monad.MonadError import sttp.monad.syntax._ import sttp.tapir.server.model.ServerResponse import sttp.tapir.server.netty.NettyResponseContent.ReactiveWebSocketProcessorNettyResponseContent -import sttp.tapir.server.netty.{NettyResponse, NettyServerRequest, Route} +import sttp.tapir.server.netty.{NettyResponse, NettyResponseContent, NettyServerRequest, Route} import scala.concurrent.{ExecutionContext, Future} import scala.util.control.NonFatal @@ -64,12 +64,23 @@ class ReactiveWebSocketHandler[F[_]]( } override def channelRead(ctx: ChannelHandlerContext, msg: Any): Unit = { - def writeError500(req: HttpRequest, reason: Throwable): Unit = { + def replyWithError500(reason: Throwable): Unit = { logger.error("Error while processing the request", reason) val res = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.INTERNAL_SERVER_ERROR) res.headers().set(HttpHeaderNames.CONTENT_LENGTH, 0) + res.headers.set(HttpHeaderNames.CONNECTION, HttpHeaderValues.CLOSE) val _ = ctx.writeAndFlush(res) } + + def rejectHandshakeForRegularEndpoint(content: NettyResponseContent): Unit = { + val message = "Unexpected WebSocket handhake on a regular HTTP endpoint" + val res = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.BAD_REQUEST, Unpooled.wrappedBuffer(message.getBytes)) + res.headers().set(HttpHeaderNames.CONTENT_LENGTH, message.length()) + res.headers.set(HttpHeaderNames.CONNECTION, HttpHeaderValues.CLOSE) + content.channelPromise.setFailure(new IllegalStateException("Unexpected response content")) + val _ = ctx.writeAndFlush(res) + } + msg match { case req: FullHttpRequest if isWsHandshake(req) => ctx.pipeline().remove(this) @@ -90,69 +101,26 @@ class ReactiveWebSocketHandler[F[_]]( case Some(function) => { val content = function(ctx) content match { - case r: ReactiveWebSocketProcessorNettyResponseContent => { - ctx - .pipeline() - .addAfter( - ServerCodecHandlerName, - WebSocketControlFrameHandlerName, - new NettyControlFrameHandler( - ignorePong = 
r.ignorePong, - autoPongOnPing = r.autoPongOnPing, - decodeCloseRequests = r.decodeCloseRequests - ) - ) - r.autoPing.foreach { case (interval, pingMsg) => - ctx - .pipeline() - .addAfter( - WebSocketControlFrameHandlerName, - WebSocketAutoPingHandlerName, - new WebSocketAutoPingHandler(interval, pingMsg) - ) - } - // Manually completing the promise, for some reason it won't be completed in writeAndFlush. We need its completion for NettyBodyListener to call back properly - r.channelPromise.setSuccess() - val _ = ctx.writeAndFlush( - // Push a special message down the pipeline, it will be handled by HttpStreamsServerHandler - // and from now on that handler will take control of the flow (our NettyServerHandler will not receive messages) - new DefaultWebSocketHttpResponse( - req.protocolVersion(), - HttpResponseStatus.valueOf(200), - r.processor, // the Processor (Pipe) created by Tapir interpreter will be used by HttpStreamsServerHandler - new WebSocketServerHandshakerFactory(wsUrl(req), null, false) - ) - ) - } + case r: ReactiveWebSocketProcessorNettyResponseContent => + initWsPipeline(ctx, r, req) case otherContent => - logger.error(s"Unexpected response content: $otherContent") - val res = new DefaultFullHttpResponse( - HttpVersion.HTTP_1_1, - HttpResponseStatus.BAD_REQUEST, - Unpooled.wrappedBuffer("WebSocket handshake received on a regular HTTP endpoint".getBytes) - ) - res.headers.set(HttpHeaderNames.CONNECTION, HttpHeaderValues.CLOSE) - otherContent.channelPromise.setFailure(new IllegalStateException("Unexpected response content")) - val _ = ctx.writeAndFlush(res) + rejectHandshakeForRegularEndpoint(otherContent) } } case None => - logger.error("Missing response body, expected WebSocketProcessorNettyResponseContent") - val res = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.INTERNAL_SERVER_ERROR) - res.headers.set(HttpHeaderNames.CONNECTION, HttpHeaderValues.CLOSE) - val _ = ctx.writeAndFlush(res) + replyWithError500(new IllegalArgumentException("Missing response body, expected WebSocketProcessorNettyResponseContent")) } Success(()) } catch { case NonFatal(ex) => - writeError500(req, ex) + replyWithError500(ex) Failure(ex) } finally { val _ = req.release() } case Failure(NonFatal(ex)) => try { - writeError500(req, ex) + replyWithError500(ex) Failure(ex) } finally { val _ = req.release() @@ -167,6 +135,45 @@ class ReactiveWebSocketHandler[F[_]]( } } + private def initWsPipeline( + ctx: ChannelHandlerContext, + r: ReactiveWebSocketProcessorNettyResponseContent, + handshakeReq: FullHttpRequest + ) = { + ctx + .pipeline() + .addAfter( + ServerCodecHandlerName, + WebSocketControlFrameHandlerName, + new NettyControlFrameHandler( + ignorePong = r.ignorePong, + autoPongOnPing = r.autoPongOnPing, + decodeCloseRequests = r.decodeCloseRequests + ) + ) + r.autoPing.foreach { case (interval, pingMsg) => + ctx + .pipeline() + .addAfter( + WebSocketControlFrameHandlerName, + WebSocketAutoPingHandlerName, + new WebSocketAutoPingHandler(interval, pingMsg) + ) + } + // Manually completing the promise, for some reason it won't be completed in writeAndFlush. 
We need its completion for NettyBodyListener to call back properly + r.channelPromise.setSuccess() + val _ = ctx.writeAndFlush( + // Push a special message down the pipeline, it will be handled by HttpStreamsServerHandler + // and from now on that handler will take control of the flow (our NettyServerHandler will not receive messages) + new DefaultWebSocketHttpResponse( + handshakeReq.protocolVersion(), + HttpResponseStatus.valueOf(200), + r.processor, // the Processor (Pipe) created by Tapir interpreter will be used by HttpStreamsServerHandler + new WebSocketServerHandshakerFactory(wsUrl(handshakeReq), null, false) + ) + ) + } + // Only ancient WS protocol versions will use this in the response header. private def wsUrl(req: FullHttpRequest): String = { val scheme = if (isSsl) "wss" else "ws" From ae1af27adea1fcfdc3a17e32a4167e62358146e7 Mon Sep 17 00:00:00 2001 From: kciesielski Date: Wed, 27 Mar 2024 11:48:06 +0100 Subject: [PATCH 21/31] Add capability to other methods --- .../tapir/server/netty/cats/NettyCatsServer.scala | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServer.scala b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServer.scala index abe2b64be4..2e722b3ac4 100644 --- a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServer.scala +++ b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServer.scala @@ -8,13 +8,14 @@ import io.netty.channel._ import io.netty.channel.group.{ChannelGroup, DefaultChannelGroup} import io.netty.channel.unix.DomainSocketAddress import io.netty.util.concurrent.DefaultEventExecutor +import sttp.capabilities.WebSockets import sttp.capabilities.fs2.Fs2Streams import sttp.monad.MonadError import sttp.tapir.integ.cats.effect.CatsMonadError import sttp.tapir.server.ServerEndpoint import sttp.tapir.server.model.ServerResponse import sttp.tapir.server.netty.cats.internal.CatsUtil.{nettyChannelFutureToScala, nettyFutureToScala} -import sttp.tapir.server.netty.internal.{NettyBootstrap, NettyServerHandler} +import sttp.tapir.server.netty.internal.{NettyBootstrap, NettyServerHandler, ReactiveWebSocketHandler} import sttp.tapir.server.netty.{NettyConfig, NettyResponse, Route} import java.net.{InetSocketAddress, SocketAddress} @@ -23,17 +24,15 @@ import java.util.UUID import java.util.concurrent.atomic.AtomicBoolean import scala.concurrent.Future import scala.concurrent.duration._ -import sttp.capabilities.WebSockets -import sttp.tapir.server.netty.internal.ReactiveWebSocketHandler case class NettyCatsServer[F[_]: Async](routes: Vector[Route[F]], options: NettyCatsServerOptions[F], config: NettyConfig) { - def addEndpoint(se: ServerEndpoint[Fs2Streams[F], F]): NettyCatsServer[F] = addEndpoints(List(se)) - def addEndpoint(se: ServerEndpoint[Fs2Streams[F], F], overrideOptions: NettyCatsServerOptions[F]): NettyCatsServer[F] = + def addEndpoint(se: ServerEndpoint[Fs2Streams[F] with WebSockets, F]): NettyCatsServer[F] = addEndpoints(List(se)) + def addEndpoint(se: ServerEndpoint[Fs2Streams[F], F] with WebSockets, overrideOptions: NettyCatsServerOptions[F]): NettyCatsServer[F] = addEndpoints(List(se), overrideOptions) def addEndpoints(ses: List[ServerEndpoint[Fs2Streams[F] with WebSockets, F]]): NettyCatsServer[F] = addRoute( NettyCatsServerInterpreter(options).toRoute(ses) ) - def addEndpoints(ses: List[ServerEndpoint[Fs2Streams[F], F]], overrideOptions: 
NettyCatsServerOptions[F]): NettyCatsServer[F] = addRoute( + def addEndpoints(ses: List[ServerEndpoint[Fs2Streams[F] with WebSockets, F]], overrideOptions: NettyCatsServerOptions[F]): NettyCatsServer[F] = addRoute( NettyCatsServerInterpreter(overrideOptions).toRoute(ses) ) From f65b0f2dfa33fac6649f3cad0880be83ea67253a Mon Sep 17 00:00:00 2001 From: kciesielski Date: Thu, 28 Mar 2024 09:00:13 +0100 Subject: [PATCH 22/31] Explicitly close the channel --- .../server/netty/internal/WebSocketControlFrameHandler.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/WebSocketControlFrameHandler.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/WebSocketControlFrameHandler.scala index f8c16bfaf0..19c0846579 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/WebSocketControlFrameHandler.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/WebSocketControlFrameHandler.scala @@ -31,7 +31,9 @@ class NettyControlFrameHandler(ignorePong: Boolean, autoPongOnPing: Boolean, dec val _ = ctx.fireChannelRead(close) } else { // Responding with Close immediately - val _ = ctx.writeAndFlush(close) + val _ = ctx + .writeAndFlush(close) + .addListener(_ => { val _ = ctx.close() }) } case other => val _ = ctx.fireChannelRead(other) From 76a978dbdece1404430606e349138fc9943ae10f Mon Sep 17 00:00:00 2001 From: kciesielski Date: Thu, 28 Mar 2024 09:05:40 +0100 Subject: [PATCH 23/31] Fix typo --- .../tapir/server/netty/internal/ReactiveWebSocketHandler.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/ReactiveWebSocketHandler.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/ReactiveWebSocketHandler.scala index 68faa242b0..437f02e1a6 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/ReactiveWebSocketHandler.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/ReactiveWebSocketHandler.scala @@ -73,7 +73,7 @@ class ReactiveWebSocketHandler[F[_]]( } def rejectHandshakeForRegularEndpoint(content: NettyResponseContent): Unit = { - val message = "Unexpected WebSocket handhake on a regular HTTP endpoint" + val message = "Unexpected WebSocket handshake on a regular HTTP endpoint" val res = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.BAD_REQUEST, Unpooled.wrappedBuffer(message.getBytes)) res.headers().set(HttpHeaderNames.CONTENT_LENGTH, message.length()) res.headers.set(HttpHeaderNames.CONNECTION, HttpHeaderValues.CLOSE) From d811daf02fa63b41452a65cedaa9a1bb6203e32c Mon Sep 17 00:00:00 2001 From: kciesielski Date: Thu, 28 Mar 2024 11:22:12 +0100 Subject: [PATCH 24/31] Review fixes --- .../server/akkahttp/AkkaHttpServerTest.scala | 9 ++- .../server/http4s/Http4sServerTest.scala | 9 ++- .../http4s/ztapir/ZHttp4sServerTest.scala | 9 ++- .../netty/cats/NettyCatsServerTest.scala | 9 ++- .../internal/ReactiveWebSocketHandler.scala | 30 +++++---- .../WebSocketControlFrameHandler.scala | 3 +- .../tapir/server/netty/internal/package.scala | 6 ++ .../pekkohttp/PekkoHttpServerTest.scala | 9 ++- .../tapir/server/play/PlayServerTest.scala | 9 ++- .../tapir/server/play/PlayServerTest.scala | 9 ++- .../server/tests/ServerWebSocketTests.scala | 65 +++++++++++++++++-- .../vertx/cats/CatsVertxServerTest.scala | 9 ++- .../tapir/server/vertx/VertxServerTest.scala | 9 ++- 
.../server/vertx/zio/ZioVertxServerTest.scala | 9 ++- .../server/ziohttp/ZioHttpServerTest.scala | 9 ++- 15 files changed, 172 insertions(+), 31 deletions(-) diff --git a/server/akka-http-server/src/test/scala/sttp/tapir/server/akkahttp/AkkaHttpServerTest.scala b/server/akka-http-server/src/test/scala/sttp/tapir/server/akkahttp/AkkaHttpServerTest.scala index 49e0fc2aaa..8df9f48fb5 100644 --- a/server/akka-http-server/src/test/scala/sttp/tapir/server/akkahttp/AkkaHttpServerTest.scala +++ b/server/akka-http-server/src/test/scala/sttp/tapir/server/akkahttp/AkkaHttpServerTest.scala @@ -157,7 +157,14 @@ class AkkaHttpServerTest extends TestSuite with EitherValues { new AllServerTests(createServerTest, interpreter, backend).tests() ++ new ServerStreamingTests(createServerTest).tests(AkkaStreams)(drainAkka) ++ - new ServerWebSocketTests(createServerTest, AkkaStreams, autoPing = false, failingPipe = true, handlePong = false) { + new ServerWebSocketTests( + createServerTest, + AkkaStreams, + autoPing = false, + failingPipe = true, + handlePong = false, + rejectNonWsEndpoints = false + ) { override def functionToPipe[A, B](f: A => B): streams.Pipe[A, B] = Flow.fromFunction(f) override def emptyPipe[A, B]: Flow[A, B, Any] = Flow.fromSinkAndSource(Sink.ignore, Source.empty) }.tests() ++ diff --git a/server/http4s-server/src/test/scala/sttp/tapir/server/http4s/Http4sServerTest.scala b/server/http4s-server/src/test/scala/sttp/tapir/server/http4s/Http4sServerTest.scala index 530aad401f..2bcd397e32 100644 --- a/server/http4s-server/src/test/scala/sttp/tapir/server/http4s/Http4sServerTest.scala +++ b/server/http4s-server/src/test/scala/sttp/tapir/server/http4s/Http4sServerTest.scala @@ -139,7 +139,14 @@ class Http4sServerTest[R >: Fs2Streams[IO] with WebSockets] extends TestSuite wi new AllServerTests(createServerTest, interpreter, backend).tests() ++ new ServerStreamingTests(createServerTest).tests(Fs2Streams[IO])(drainFs2) ++ - new ServerWebSocketTests(createServerTest, Fs2Streams[IO], autoPing = true, failingPipe = true, handlePong = false) { + new ServerWebSocketTests( + createServerTest, + Fs2Streams[IO], + autoPing = true, + failingPipe = true, + handlePong = false, + rejectNonWsEndpoints = false + ) { override def functionToPipe[A, B](f: A => B): streams.Pipe[A, B] = in => in.map(f) override def emptyPipe[A, B]: Pipe[IO, A, B] = _ => fs2.Stream.empty }.tests() ++ diff --git a/server/http4s-server/zio/src/test/scala/sttp/tapir/server/http4s/ztapir/ZHttp4sServerTest.scala b/server/http4s-server/zio/src/test/scala/sttp/tapir/server/http4s/ztapir/ZHttp4sServerTest.scala index b2cfc5ae58..55011b7f5f 100644 --- a/server/http4s-server/zio/src/test/scala/sttp/tapir/server/http4s/ztapir/ZHttp4sServerTest.scala +++ b/server/http4s-server/zio/src/test/scala/sttp/tapir/server/http4s/ztapir/ZHttp4sServerTest.scala @@ -55,7 +55,14 @@ class ZHttp4sServerTest extends TestSuite with OptionValues { new AllServerTests(createServerTest, interpreter, backend).tests() ++ new ServerStreamingTests(createServerTest).tests(ZioStreams)(drainZStream) ++ - new ServerWebSocketTests(createServerTest, ZioStreams, autoPing = true, failingPipe = false, handlePong = false) { + new ServerWebSocketTests( + createServerTest, + ZioStreams, + autoPing = true, + failingPipe = false, + handlePong = false, + rejectNonWsEndpoints = false + ) { override def functionToPipe[A, B](f: A => B): streams.Pipe[A, B] = in => in.map(f) override def emptyPipe[A, B]: streams.Pipe[A, B] = _ => ZStream.empty }.tests() ++ diff --git 
a/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsServerTest.scala b/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsServerTest.scala index 49dfc5d479..74c67cf028 100644 --- a/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsServerTest.scala +++ b/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsServerTest.scala @@ -41,7 +41,14 @@ class NettyCatsServerTest extends TestSuite with EitherValues { new ServerCancellationTests(createServerTest)(m, IO.asyncForIO).tests() ++ new NettyFs2StreamingCancellationTest(createServerTest).tests() ++ new ServerGracefulShutdownTests(createServerTest, ioSleeper).tests() ++ - new ServerWebSocketTests(createServerTest, Fs2Streams[IO], autoPing = true, failingPipe = true, handlePong = true) { + new ServerWebSocketTests( + createServerTest, + Fs2Streams[IO], + autoPing = true, + failingPipe = true, + handlePong = true, + rejectNonWsEndpoints = true + ) { override def functionToPipe[A, B](f: A => B): streams.Pipe[A, B] = in => in.map(f) override def emptyPipe[A, B]: fs2.Pipe[IO, A, B] = _ => fs2.Stream.empty }.tests() diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/ReactiveWebSocketHandler.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/ReactiveWebSocketHandler.scala index 437f02e1a6..c2e21f31d3 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/ReactiveWebSocketHandler.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/ReactiveWebSocketHandler.scala @@ -2,7 +2,7 @@ package sttp.tapir.server.netty.internal import io.netty.buffer.Unpooled import io.netty.channel.group.ChannelGroup -import io.netty.channel.{ChannelHandlerContext, ChannelInboundHandlerAdapter} +import io.netty.channel.{ChannelFuture, ChannelHandlerContext, ChannelInboundHandlerAdapter} import io.netty.handler.codec.http._ import io.netty.handler.codec.http.websocketx.WebSocketServerHandshakerFactory import io.netty.handler.timeout.ReadTimeoutHandler @@ -69,16 +69,22 @@ class ReactiveWebSocketHandler[F[_]]( val res = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.INTERNAL_SERVER_ERROR) res.headers().set(HttpHeaderNames.CONTENT_LENGTH, 0) res.headers.set(HttpHeaderNames.CONNECTION, HttpHeaderValues.CLOSE) - val _ = ctx.writeAndFlush(res) + val _ = ctx.writeAndFlush(res).close() } - def rejectHandshakeForRegularEndpoint(content: NettyResponseContent): Unit = { + def rejectHandshakeForRegularEndpoint(content: Option[NettyResponseContent]): Unit = { val message = "Unexpected WebSocket handshake on a regular HTTP endpoint" val res = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.BAD_REQUEST, Unpooled.wrappedBuffer(message.getBytes)) res.headers().set(HttpHeaderNames.CONTENT_LENGTH, message.length()) res.headers.set(HttpHeaderNames.CONNECTION, HttpHeaderValues.CLOSE) - content.channelPromise.setFailure(new IllegalStateException("Unexpected response content")) - val _ = ctx.writeAndFlush(res) + content.foreach(_.channelPromise.setFailure(new IllegalStateException("Unexpected response content"))) + val _ = ctx.writeAndFlush(res).close() + } + + def replyNotFound(): Unit = { + val res = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.NOT_FOUND) + res.headers.set(HttpHeaderNames.CONNECTION, HttpHeaderValues.CLOSE) + val _ = ctx.writeAndFlush(res).close() } msg match { @@ -94,7 +100,9 @@ class 
ReactiveWebSocketHandler[F[_]]( } } - val _ = runningFuture.transform { + runningFuture.onComplete { + case Success(serverResponse) if serverResponse == ServerResponse.notFound => + replyNotFound() case Success(serverResponse) => try { serverResponse.body match { @@ -104,28 +112,24 @@ class ReactiveWebSocketHandler[F[_]]( case r: ReactiveWebSocketProcessorNettyResponseContent => initWsPipeline(ctx, r, req) case otherContent => - rejectHandshakeForRegularEndpoint(otherContent) + rejectHandshakeForRegularEndpoint(Some(otherContent)) } } case None => - replyWithError500(new IllegalArgumentException("Missing response body, expected WebSocketProcessorNettyResponseContent")) + rejectHandshakeForRegularEndpoint(content = None) } - Success(()) } catch { case NonFatal(ex) => replyWithError500(ex) - Failure(ex) } finally { val _ = req.release() } - case Failure(NonFatal(ex)) => + case Failure(ex) => try { replyWithError500(ex) - Failure(ex) } finally { val _ = req.release() } - case Failure(fatalException) => Failure(fatalException) }(eventLoopContext) case other => diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/WebSocketControlFrameHandler.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/WebSocketControlFrameHandler.scala index 19c0846579..72afc98725 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/WebSocketControlFrameHandler.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/WebSocketControlFrameHandler.scala @@ -32,8 +32,7 @@ class NettyControlFrameHandler(ignorePong: Boolean, autoPongOnPing: Boolean, dec } else { // Responding with Close immediately val _ = ctx - .writeAndFlush(close) - .addListener(_ => { val _ = ctx.close() }) + .writeAndFlush(close).close() } case other => val _ = ctx.fireChannelRead(other) diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/package.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/package.scala index fa0f0f99ee..2615d68b8c 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/package.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/package.scala @@ -1,5 +1,6 @@ package sttp.tapir.server.netty +import io.netty.channel.{ChannelFuture, ChannelFutureListener} import io.netty.handler.codec.http.HttpHeaders import sttp.model.Header @@ -11,6 +12,11 @@ package object internal { underlying.asScala.map(e => Header(e.getKey, e.getValue)).toList } + implicit class RichChannelFuture(val cf: ChannelFuture) { + def close(): Unit = { + val _ = cf.addListener(ChannelFutureListener.CLOSE) + } + } val ServerCodecHandlerName = "serverCodecHandler" val WebSocketControlFrameHandlerName = "wsControlFrameHandler" } diff --git a/server/pekko-http-server/src/test/scala/sttp/tapir/server/pekkohttp/PekkoHttpServerTest.scala b/server/pekko-http-server/src/test/scala/sttp/tapir/server/pekkohttp/PekkoHttpServerTest.scala index a48dc86063..292bad134b 100644 --- a/server/pekko-http-server/src/test/scala/sttp/tapir/server/pekkohttp/PekkoHttpServerTest.scala +++ b/server/pekko-http-server/src/test/scala/sttp/tapir/server/pekkohttp/PekkoHttpServerTest.scala @@ -105,7 +105,14 @@ class PekkoHttpServerTest extends TestSuite with EitherValues { new AllServerTests(createServerTest, interpreter, backend).tests() ++ new ServerStreamingTests(createServerTest).tests(PekkoStreams)(drainPekko) ++ - new ServerWebSocketTests(createServerTest, PekkoStreams, autoPing = false, 
failingPipe = true, handlePong = false) { + new ServerWebSocketTests( + createServerTest, + PekkoStreams, + autoPing = false, + failingPipe = true, + handlePong = false, + rejectNonWsEndpoints = false + ) { override def functionToPipe[A, B](f: A => B): streams.Pipe[A, B] = Flow.fromFunction(f) override def emptyPipe[A, B]: Flow[A, B, Any] = Flow.fromSinkAndSource(Sink.ignore, Source.empty) }.tests() ++ diff --git a/server/play-server/src/test/scala/sttp/tapir/server/play/PlayServerTest.scala b/server/play-server/src/test/scala/sttp/tapir/server/play/PlayServerTest.scala index af479d230e..2d6e67fb70 100644 --- a/server/play-server/src/test/scala/sttp/tapir/server/play/PlayServerTest.scala +++ b/server/play-server/src/test/scala/sttp/tapir/server/play/PlayServerTest.scala @@ -122,7 +122,14 @@ class PlayServerTest extends TestSuite { ).tests() ++ new ServerStreamingTests(createServerTest).tests(PekkoStreams)(drainPekko) ++ new PlayServerWithContextTest(backend).tests() ++ - new ServerWebSocketTests(createServerTest, PekkoStreams, autoPing = false, failingPipe = true, handlePong = false) { + new ServerWebSocketTests( + createServerTest, + PekkoStreams, + autoPing = false, + failingPipe = true, + handlePong = false, + rejectNonWsEndpoints = false + ) { override def functionToPipe[A, B](f: A => B): streams.Pipe[A, B] = Flow.fromFunction(f) override def emptyPipe[A, B]: Flow[A, B, Any] = Flow.fromSinkAndSource(Sink.ignore, Source.empty) }.tests() ++ diff --git a/server/play29-server/src/test/scala/sttp/tapir/server/play/PlayServerTest.scala b/server/play29-server/src/test/scala/sttp/tapir/server/play/PlayServerTest.scala index 2634bc5fa9..663ee13e3a 100644 --- a/server/play29-server/src/test/scala/sttp/tapir/server/play/PlayServerTest.scala +++ b/server/play29-server/src/test/scala/sttp/tapir/server/play/PlayServerTest.scala @@ -117,7 +117,14 @@ class PlayServerTest extends TestSuite { new AllServerTests(createServerTest, interpreter, backend, basic = false, multipart = false, options = false).tests() ++ new ServerStreamingTests(createServerTest).tests(AkkaStreams)(drainAkka) ++ new PlayServerWithContextTest(backend).tests() ++ - new ServerWebSocketTests(createServerTest, AkkaStreams, autoPing = false, failingPipe = true, handlePong = false) { + new ServerWebSocketTests( + createServerTest, + AkkaStreams, + autoPing = false, + failingPipe = true, + handlePong = false, + rejectNonWsEndpoints = true + ) { override def functionToPipe[A, B](f: A => B): streams.Pipe[A, B] = Flow.fromFunction(f) override def emptyPipe[A, B]: Flow[A, B, Any] = Flow.fromSinkAndSource(Sink.ignore, Source.empty) }.tests() ++ diff --git a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerWebSocketTests.scala b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerWebSocketTests.scala index 09faceac0a..b2b591101a 100644 --- a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerWebSocketTests.scala +++ b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerWebSocketTests.scala @@ -3,14 +3,16 @@ package sttp.tapir.server.tests import cats.effect.IO import cats.syntax.all._ import io.circe.generic.auto._ -import org.scalatest.matchers.should.Matchers._ import org.scalatest.EitherValues +import org.scalatest.matchers.should.Matchers._ import sttp.capabilities.{Streams, WebSockets} import sttp.client3._ +import sttp.model.StatusCode import sttp.monad.MonadError import sttp.tapir._ import sttp.tapir.generic.auto._ import sttp.tapir.json.circe._ +import sttp.tapir.model.UnsupportedWebSocketFrameException 
import sttp.tapir.server.interceptor.CustomiseInterceptors import sttp.tapir.server.interceptor.metrics.MetricsRequestInterceptor import sttp.tapir.server.tests.ServerMetricsTest._ @@ -19,14 +21,14 @@ import sttp.tapir.tests.data.Fruit import sttp.ws.{WebSocket, WebSocketFrame} import scala.concurrent.duration._ -import sttp.tapir.model.UnsupportedWebSocketFrameException abstract class ServerWebSocketTests[F[_], S <: Streams[S], OPTIONS, ROUTE]( createServerTest: CreateServerTest[F, S with WebSockets, OPTIONS, ROUTE], val streams: S, autoPing: Boolean, failingPipe: Boolean, - handlePong: Boolean + handlePong: Boolean, + rejectNonWsEndpoints: Boolean )(implicit m: MonadError[F] ) extends EitherValues { @@ -57,6 +59,21 @@ abstract class ServerWebSocketTests[F[_], S <: Streams[S], OPTIONS, ROUTE]( .get(baseUri.scheme("ws")) .send(backend) .map(_.body shouldBe Right(List("echo: test1", "echo: test2", Left(WebSocketFrame.Close(1000, "normal closure"))))) + }, + testServer( + endpoint.in("elsewhere").out(stringBody), + "WS handshake to a non-existing endpoint" + )((_: Unit) => pureResult("hello".asRight[Unit])) { (backend, baseUri) => + basicRequest + .response(asWebSocket { (ws: WebSocket[IO]) => + for { + _ <- ws.sendText("test") + m1 <- ws.receiveText() + } yield m1 + }) + .get(baseUri.scheme("ws")) + .send(backend) + .map(_.code shouldBe StatusCode.NotFound) }, { val reqCounter = newRequestCounter[F] @@ -202,9 +219,11 @@ abstract class ServerWebSocketTests[F[_], S <: Streams[S], OPTIONS, ROUTE]( .response(asString) .get(baseUri.scheme("http")) .send(backend) - .map(_.body shouldBe Left("Not a WS!")) + .map { r => + r.body shouldBe Left("Not a WS!") + } } - ) ++ autoPingTests ++ failingPipeTests ++ handlePongTests + ) ++ autoPingTests ++ failingPipeTests ++ handlePongTests ++ rejectNonWsEndpointsTests val autoPingTests = if (autoPing) @@ -316,4 +335,40 @@ abstract class ServerWebSocketTests[F[_], S <: Streams[S], OPTIONS, ROUTE]( } ) else List.empty + + val rejectNonWsEndpointsTests = + if (rejectNonWsEndpoints) + List( + testServer( + endpoint.out(stringBody), + "WS handshake to a non-WS endpoint" + )((_: Unit) => pureResult("hello".asRight[Unit])) { (backend, baseUri) => + basicRequest + .response(asWebSocket { (ws: WebSocket[IO]) => + for { + _ <- ws.sendText("test") + m1 <- ws.receiveText() + } yield List(m1) + }) + .get(baseUri.scheme("ws")) + .send(backend) + .map(_.code shouldBe StatusCode.BadRequest) + }, + testServer( + endpoint.out(emptyOutput), + "WS handshake to a non-WS endpoint with empty output" // to make sure this won't be treated as 404 + )((_: Unit) => pureResult(().asRight[Unit])) { (backend, baseUri) => + basicRequest + .response(asWebSocket { (ws: WebSocket[IO]) => + for { + _ <- ws.sendText("test") + m1 <- ws.receiveText() + } yield List(m1) + }) + .get(baseUri.scheme("ws")) + .send(backend) + .map(_.code shouldBe StatusCode.BadRequest) + } + ) + else List.empty } diff --git a/server/vertx-server/cats/src/test/scala/sttp/tapir/server/vertx/cats/CatsVertxServerTest.scala b/server/vertx-server/cats/src/test/scala/sttp/tapir/server/vertx/cats/CatsVertxServerTest.scala index aac2e40122..70191e4241 100644 --- a/server/vertx-server/cats/src/test/scala/sttp/tapir/server/vertx/cats/CatsVertxServerTest.scala +++ b/server/vertx-server/cats/src/test/scala/sttp/tapir/server/vertx/cats/CatsVertxServerTest.scala @@ -37,7 +37,14 @@ class CatsVertxServerTest extends TestSuite { partOtherHeaderSupport = false ).tests() ++ new 
ServerStreamingTests(createServerTest).tests(Fs2Streams.apply[IO])(drainFs2) ++ - new ServerWebSocketTests(createServerTest, Fs2Streams.apply[IO], autoPing = false, failingPipe = true, handlePong = true) { + new ServerWebSocketTests( + createServerTest, + Fs2Streams.apply[IO], + autoPing = false, + failingPipe = true, + handlePong = true, + rejectNonWsEndpoints = false + ) { override def functionToPipe[A, B](f: A => B): streams.Pipe[A, B] = in => in.map(f) override def emptyPipe[A, B]: streams.Pipe[A, B] = _ => Stream.empty }.tests() diff --git a/server/vertx-server/src/test/scala/sttp/tapir/server/vertx/VertxServerTest.scala b/server/vertx-server/src/test/scala/sttp/tapir/server/vertx/VertxServerTest.scala index 394a8d22a2..cf5c6157aa 100644 --- a/server/vertx-server/src/test/scala/sttp/tapir/server/vertx/VertxServerTest.scala +++ b/server/vertx-server/src/test/scala/sttp/tapir/server/vertx/VertxServerTest.scala @@ -53,7 +53,14 @@ class VertxServerTest extends TestSuite { partContentTypeHeaderSupport = true, partOtherHeaderSupport = false ).tests() ++ new ServerStreamingTests(createServerTest).tests(VertxStreams)(drainVertx[Buffer]) ++ - (new ServerWebSocketTests(createServerTest, VertxStreams, autoPing = false, failingPipe = false, handlePong = true) { + (new ServerWebSocketTests( + createServerTest, + VertxStreams, + autoPing = false, + failingPipe = false, + handlePong = true, + rejectNonWsEndpoints = false + ) { override def functionToPipe[A, B](f: A => B): VertxStreams.Pipe[A, B] = in => new ReadStreamMapping(in, f) override def emptyPipe[A, B]: VertxStreams.Pipe[A, B] = _ => new EmptyReadStream() }).tests() diff --git a/server/vertx-server/zio/src/test/scala/sttp/tapir/server/vertx/zio/ZioVertxServerTest.scala b/server/vertx-server/zio/src/test/scala/sttp/tapir/server/vertx/zio/ZioVertxServerTest.scala index f2ca8846c5..3f112f5ba7 100644 --- a/server/vertx-server/zio/src/test/scala/sttp/tapir/server/vertx/zio/ZioVertxServerTest.scala +++ b/server/vertx-server/zio/src/test/scala/sttp/tapir/server/vertx/zio/ZioVertxServerTest.scala @@ -43,7 +43,14 @@ class ZioVertxServerTest extends TestSuite with OptionValues { partOtherHeaderSupport = false ).tests() ++ additionalTests() ++ new ServerStreamingTests(createServerTest).tests(ZioStreams)(drainZStream) ++ - new ServerWebSocketTests(createServerTest, ZioStreams, autoPing = true, failingPipe = true, handlePong = true) { + new ServerWebSocketTests( + createServerTest, + ZioStreams, + autoPing = true, + failingPipe = true, + handlePong = true, + rejectNonWsEndpoints = false + ) { override def functionToPipe[A, B](f: A => B): streams.Pipe[A, B] = in => in.map(f) override def emptyPipe[A, B]: streams.Pipe[A, B] = _ => ZStream.empty }.tests() diff --git a/server/zio-http-server/src/test/scala/sttp/tapir/server/ziohttp/ZioHttpServerTest.scala b/server/zio-http-server/src/test/scala/sttp/tapir/server/ziohttp/ZioHttpServerTest.scala index 7e6043166a..e0b71f21b1 100644 --- a/server/zio-http-server/src/test/scala/sttp/tapir/server/ziohttp/ZioHttpServerTest.scala +++ b/server/zio-http-server/src/test/scala/sttp/tapir/server/ziohttp/ZioHttpServerTest.scala @@ -266,7 +266,14 @@ class ZioHttpServerTest extends TestSuite { ).tests() ++ new ServerStreamingTests(createServerTest).tests(ZioStreams)(drainZStream) ++ new ZioHttpCompositionTest(createServerTest).tests() ++ - new ServerWebSocketTests(createServerTest, ZioStreams, autoPing = true, failingPipe = false, handlePong = false) { + new ServerWebSocketTests( + createServerTest, + ZioStreams, + autoPing 
= true, + failingPipe = false, + handlePong = false, + rejectNonWsEndpoints = false + ) { override def functionToPipe[A, B](f: A => B): ZioStreams.Pipe[A, B] = in => in.map(f) override def emptyPipe[A, B]: ZioStreams.Pipe[A, B] = _ => ZStream.empty }.tests() ++ From 8e2efc72743f1a5606cd006631a3ba570bffc597 Mon Sep 17 00:00:00 2001 From: kciesielski Date: Thu, 28 Mar 2024 11:31:38 +0100 Subject: [PATCH 25/31] Move ws-specific stuff to its own package --- .../tapir/server/netty/cats/NettyCatsServer.scala | 8 ++++++-- .../cats/internal/WebSocketPipeProcessor.scala | 2 +- .../sttp/tapir/server/netty/NettyFutureServer.scala | 2 +- .../server/netty/internal/NettyBootstrap.scala | 1 + .../{ => ws}/ReactiveWebSocketHandler.scala | 3 ++- .../{ => ws}/WebSocketAutoPingHandler.scala | 2 +- .../{ => ws}/WebSocketControlFrameHandler.scala | 13 ++++++------- .../{ => ws}/WebSocketFrameConverters.scala | 2 +- 8 files changed, 19 insertions(+), 14 deletions(-) rename server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/{ => ws}/ReactiveWebSocketHandler.scala (98%) rename server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/{ => ws}/WebSocketAutoPingHandler.scala (97%) rename server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/{ => ws}/WebSocketControlFrameHandler.scala (73%) rename server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/{ => ws}/WebSocketFrameConverters.scala (97%) diff --git a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServer.scala b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServer.scala index 2e722b3ac4..6863559557 100644 --- a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServer.scala +++ b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServer.scala @@ -15,7 +15,8 @@ import sttp.tapir.integ.cats.effect.CatsMonadError import sttp.tapir.server.ServerEndpoint import sttp.tapir.server.model.ServerResponse import sttp.tapir.server.netty.cats.internal.CatsUtil.{nettyChannelFutureToScala, nettyFutureToScala} -import sttp.tapir.server.netty.internal.{NettyBootstrap, NettyServerHandler, ReactiveWebSocketHandler} +import sttp.tapir.server.netty.internal.{NettyBootstrap, NettyServerHandler} +import sttp.tapir.server.netty.internal.ws.ReactiveWebSocketHandler import sttp.tapir.server.netty.{NettyConfig, NettyResponse, Route} import java.net.{InetSocketAddress, SocketAddress} @@ -32,7 +33,10 @@ case class NettyCatsServer[F[_]: Async](routes: Vector[Route[F]], options: Netty def addEndpoints(ses: List[ServerEndpoint[Fs2Streams[F] with WebSockets, F]]): NettyCatsServer[F] = addRoute( NettyCatsServerInterpreter(options).toRoute(ses) ) - def addEndpoints(ses: List[ServerEndpoint[Fs2Streams[F] with WebSockets, F]], overrideOptions: NettyCatsServerOptions[F]): NettyCatsServer[F] = addRoute( + def addEndpoints( + ses: List[ServerEndpoint[Fs2Streams[F] with WebSockets, F]], + overrideOptions: NettyCatsServerOptions[F] + ): NettyCatsServer[F] = addRoute( NettyCatsServerInterpreter(overrideOptions).toRoute(ses) ) diff --git a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/WebSocketPipeProcessor.scala b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/WebSocketPipeProcessor.scala index 441d5d60ef..d6d08100ab 100644 --- a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/WebSocketPipeProcessor.scala +++ 
b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/WebSocketPipeProcessor.scala @@ -13,7 +13,7 @@ import org.reactivestreams.{Processor, Publisher, Subscriber, Subscription} import org.slf4j.LoggerFactory import sttp.capabilities.fs2.Fs2Streams import sttp.tapir.model.WebSocketFrameDecodeFailure -import sttp.tapir.server.netty.internal.WebSocketFrameConverters._ +import sttp.tapir.server.netty.internal.ws.WebSocketFrameConverters._ import sttp.tapir.{DecodeResult, WebSocketBodyOutput} import sttp.ws.WebSocketFrame diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyFutureServer.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyFutureServer.scala index 2c972c3226..338a3e19e8 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyFutureServer.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyFutureServer.scala @@ -9,6 +9,7 @@ import sttp.tapir.server.ServerEndpoint import sttp.tapir.server.model.ServerResponse import sttp.tapir.server.netty.internal.FutureUtil._ import sttp.tapir.server.netty.internal.{NettyBootstrap, NettyServerHandler} +import sttp.tapir.server.netty.internal.ws.ReactiveWebSocketHandler import java.net.{InetSocketAddress, SocketAddress} import java.nio.file.{Path, Paths} @@ -16,7 +17,6 @@ import java.util.UUID import java.util.concurrent.atomic.AtomicBoolean import scala.concurrent.duration.FiniteDuration import scala.concurrent.{ExecutionContext, Future, blocking} -import sttp.tapir.server.netty.internal.ReactiveWebSocketHandler case class NettyFutureServer(routes: Vector[FutureRoute], options: NettyFutureServerOptions, config: NettyConfig)(implicit ec: ExecutionContext diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyBootstrap.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyBootstrap.scala index 5c864b3fa9..228a8892ba 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyBootstrap.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyBootstrap.scala @@ -4,6 +4,7 @@ import io.netty.bootstrap.ServerBootstrap import io.netty.channel.{Channel, ChannelFuture, ChannelInitializer, ChannelOption, EventLoopGroup} import io.netty.handler.timeout.ReadTimeoutHandler import sttp.tapir.server.netty.NettyConfig +import ws.ReactiveWebSocketHandler import java.net.{InetSocketAddress, SocketAddress} diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/ReactiveWebSocketHandler.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/ws/ReactiveWebSocketHandler.scala similarity index 98% rename from server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/ReactiveWebSocketHandler.scala rename to server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/ws/ReactiveWebSocketHandler.scala index c2e21f31d3..e2c9c7bff0 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/ReactiveWebSocketHandler.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/ws/ReactiveWebSocketHandler.scala @@ -1,4 +1,4 @@ -package sttp.tapir.server.netty.internal +package sttp.tapir.server.netty.internal.ws import io.netty.buffer.Unpooled import io.netty.channel.group.ChannelGroup @@ -13,6 +13,7 @@ import sttp.monad.MonadError import sttp.monad.syntax._ import sttp.tapir.server.model.ServerResponse import 
sttp.tapir.server.netty.NettyResponseContent.ReactiveWebSocketProcessorNettyResponseContent +import sttp.tapir.server.netty.internal._ import sttp.tapir.server.netty.{NettyResponse, NettyResponseContent, NettyServerRequest, Route} import scala.concurrent.{ExecutionContext, Future} diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/WebSocketAutoPingHandler.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/ws/WebSocketAutoPingHandler.scala similarity index 97% rename from server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/WebSocketAutoPingHandler.scala rename to server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/ws/WebSocketAutoPingHandler.scala index c62eb6ba9a..49838d1f59 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/WebSocketAutoPingHandler.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/ws/WebSocketAutoPingHandler.scala @@ -1,4 +1,4 @@ -package sttp.tapir.server.netty.internal +package sttp.tapir.server.netty.internal.ws import io.netty.buffer.Unpooled import io.netty.channel.{ChannelHandlerContext, ChannelInboundHandlerAdapter} diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/WebSocketControlFrameHandler.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/ws/WebSocketControlFrameHandler.scala similarity index 73% rename from server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/WebSocketControlFrameHandler.scala rename to server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/ws/WebSocketControlFrameHandler.scala index 72afc98725..04cfb931be 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/WebSocketControlFrameHandler.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/ws/WebSocketControlFrameHandler.scala @@ -1,10 +1,8 @@ -package sttp.tapir.server.netty.internal +package sttp.tapir.server.netty.internal.ws -import io.netty.channel.ChannelInboundHandlerAdapter -import io.netty.channel.ChannelHandlerContext -import io.netty.handler.codec.http.websocketx.PingWebSocketFrame -import io.netty.handler.codec.http.websocketx.PongWebSocketFrame -import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame +import io.netty.channel.{ChannelHandlerContext, ChannelInboundHandlerAdapter} +import io.netty.handler.codec.http.websocketx.{CloseWebSocketFrame, PingWebSocketFrame, PongWebSocketFrame} +import sttp.tapir.server.netty.internal._ /** Handles Ping, Pong, and Close frames for WebSockets. 
*/ @@ -32,7 +30,8 @@ class NettyControlFrameHandler(ignorePong: Boolean, autoPongOnPing: Boolean, dec } else { // Responding with Close immediately val _ = ctx - .writeAndFlush(close).close() + .writeAndFlush(close) + .close() } case other => val _ = ctx.fireChannelRead(other) diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/WebSocketFrameConverters.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/ws/WebSocketFrameConverters.scala similarity index 97% rename from server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/WebSocketFrameConverters.scala rename to server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/ws/WebSocketFrameConverters.scala index e2d866aeac..d9e7c75cd1 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/WebSocketFrameConverters.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/ws/WebSocketFrameConverters.scala @@ -1,4 +1,4 @@ -package sttp.tapir.server.netty.internal +package sttp.tapir.server.netty.internal.ws import io.netty.buffer.{ByteBuf, Unpooled} import io.netty.handler.codec.http.websocketx._ From 9cf4e42b708c442fe2d255bf1122f2b84c6d47a4 Mon Sep 17 00:00:00 2001 From: kciesielski Date: Thu, 28 Mar 2024 12:11:31 +0100 Subject: [PATCH 26/31] Add handlers in a list --- .../server/netty/cats/NettyCatsServer.scala | 6 ++++-- .../cats/internal/Fs2StreamCompatible.scala | 6 +++--- .../tapir/server/netty/loom/NettyIdServer.scala | 16 ++++++++-------- .../sttp/tapir/server/netty/NettyConfig.scala | 12 +++++------- .../tapir/server/netty/NettyFutureServer.scala | 3 +-- .../server/netty/internal/NettyBootstrap.scala | 11 ++++------- .../tapir/server/netty/zio/NettyZioServer.scala | 17 +++++++++-------- 7 files changed, 34 insertions(+), 37 deletions(-) diff --git a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServer.scala b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServer.scala index 6863559557..0c18fdefa1 100644 --- a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServer.scala +++ b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServer.scala @@ -79,8 +79,10 @@ case class NettyCatsServer[F[_]: Async](routes: Vector[Route[F]], options: Netty val channelFuture = NettyBootstrap( config, - new NettyServerHandler(route, unsafeRunAsync, channelGroup, isShuttingDown, config.serverHeader), - new ReactiveWebSocketHandler(route, channelGroup, unsafeRunAsync, config.sslContext.isDefined), + List( + new ReactiveWebSocketHandler(route, channelGroup, unsafeRunAsync, config.sslContext.isDefined), + new NettyServerHandler(route, unsafeRunAsync, channelGroup, isShuttingDown, config.serverHeader) + ), eventLoopGroup, socketOverride ) diff --git a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/Fs2StreamCompatible.scala b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/Fs2StreamCompatible.scala index 1b5ea432a7..8f257f3559 100644 --- a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/Fs2StreamCompatible.scala +++ b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/Fs2StreamCompatible.scala @@ -71,8 +71,8 @@ object Fs2StreamCompatible { o: WebSocketBodyOutput[Pipe[F, REQ, RESP], REQ, RESP, ?, Fs2Streams[F]], ctx: ChannelHandlerContext ): Processor[WebSocketFrame, WebSocketFrame] = { - val processorPromise = 
ctx.newPromise() - processorPromise.addListener((f: ChannelFuture) => { + val wsCompletedPromise = ctx.newPromise() + wsCompletedPromise.addListener((f: ChannelFuture) => { // A special callback that has to be used when a SteramSubscription cancels or fails. // This can happen in case of errors in the pipeline which are not signalled correctly, // like throwing exceptions directly. @@ -83,7 +83,7 @@ object Fs2StreamCompatible { val _ = ctx.writeAndFlush(new CloseWebSocketFrame(WebSocketCloseStatus.INTERNAL_SERVER_ERROR, "Error")) } }) - new WebSocketPipeProcessor[F, REQ, RESP](pipe, dispatcher, o, processorPromise) + new WebSocketPipeProcessor[F, REQ, RESP](pipe, dispatcher, o, wsCompletedPromise) } private def inputStreamToFs2(inputStream: () => InputStream, chunkSize: Int) = diff --git a/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/NettyIdServer.scala b/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/NettyIdServer.scala index 2de073f763..53eb3b4d56 100644 --- a/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/NettyIdServer.scala +++ b/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/NettyIdServer.scala @@ -25,7 +25,6 @@ import scala.concurrent.Future import scala.concurrent.Promise import scala.concurrent.duration.FiniteDuration import scala.util.control.NonFatal -import sttp.tapir.server.netty.internal.ReactiveWebSocketHandler case class NettyIdServer(routes: Vector[IdRoute], options: NettyIdServerOptions, config: NettyConfig) { private val executor = Executors.newVirtualThreadPerTaskExecutor() @@ -94,14 +93,15 @@ case class NettyIdServer(routes: Vector[IdRoute], options: NettyIdServerOptions, val channelIdFuture = NettyBootstrap( config, - new NettyServerHandler( - route, - unsafeRunF, - channelGroup, - isShuttingDown, - config.serverHeader + List( + new NettyServerHandler( + route, + unsafeRunF, + channelGroup, + isShuttingDown, + config.serverHeader + ) ), - new ReactiveWebSocketHandler(route, channelGroup, unsafeRunF, config.sslContext.isDefined), eventLoopGroup, socketOverride ) diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyConfig.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyConfig.scala index 115c387be1..05f7cd165e 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyConfig.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyConfig.scala @@ -62,7 +62,7 @@ case class NettyConfig( sslContext: Option[SslContext], eventLoopConfig: EventLoopConfig, socketConfig: NettySocketConfig, - initPipeline: NettyConfig => (ChannelPipeline, ChannelHandler, ChannelHandler) => Unit, + initPipeline: NettyConfig => (ChannelPipeline, List[ChannelHandler]) => Unit, gracefulShutdownTimeout: Option[FiniteDuration], serverHeader: Option[String] ) { @@ -97,7 +97,7 @@ case class NettyConfig( def eventLoopConfig(elc: EventLoopConfig): NettyConfig = copy(eventLoopConfig = elc) def eventLoopGroup(elg: EventLoopGroup): NettyConfig = copy(eventLoopConfig = EventLoopConfig.useExisting(elg)) - def initPipeline(f: NettyConfig => (ChannelPipeline, ChannelHandler, ChannelHandler) => Unit): NettyConfig = copy(initPipeline = f) + def initPipeline(f: NettyConfig => (ChannelPipeline, List[ChannelHandler]) => Unit): NettyConfig = copy(initPipeline = f) def withGracefulShutdownTimeout(t: FiniteDuration) = copy(gracefulShutdownTimeout = Some(t)) def noGracefulShutdown = copy(gracefulShutdownTimeout = None) @@ -106,7 +106,6 @@ case class 
NettyConfig( } object NettyConfig { - private val WebSocketHandlerName = "wsHandler" val StreamsHandlerName = "streamsHandler" def default: NettyConfig = NettyConfig( @@ -124,16 +123,15 @@ object NettyConfig { sslContext = None, eventLoopConfig = EventLoopConfig.auto, socketConfig = NettySocketConfig.default, - initPipeline = cfg => defaultInitPipeline(cfg)(_, _, _), + initPipeline = cfg => defaultInitPipeline(cfg)(_, _), serverHeader = Some(s"tapir/${buildinfo.BuildInfo.version}") ) - def defaultInitPipeline(cfg: NettyConfig)(pipeline: ChannelPipeline, handler: ChannelHandler, wsHandler: ChannelHandler): Unit = { + def defaultInitPipeline(cfg: NettyConfig)(pipeline: ChannelPipeline, handlers: List[ChannelHandler]): Unit = { cfg.sslContext.foreach(s => pipeline.addLast(s.newHandler(pipeline.channel().alloc()))) pipeline.addLast(ServerCodecHandlerName, new HttpServerCodec()) pipeline.addLast(StreamsHandlerName, new HttpStreamsServerHandler()) - pipeline.addLast(WebSocketHandlerName, wsHandler) - pipeline.addLast(handler) + handlers.foreach(pipeline.addLast(_)) if (cfg.addLoggingHandler) pipeline.addLast(new LoggingHandler()) () } diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyFutureServer.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyFutureServer.scala index 338a3e19e8..0a5ad68100 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyFutureServer.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyFutureServer.scala @@ -72,8 +72,7 @@ case class NettyFutureServer(routes: Vector[FutureRoute], options: NettyFutureSe val channelFuture = NettyBootstrap( config, - new NettyServerHandler(route, unsafeRunAsync, channelGroup, isShuttingDown, config.serverHeader), - new ReactiveWebSocketHandler(route, channelGroup, unsafeRunAsync, config.sslContext.isDefined), + List(new NettyServerHandler(route, unsafeRunAsync, channelGroup, isShuttingDown, config.serverHeader)), eventLoopGroup, socketOverride ) diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyBootstrap.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyBootstrap.scala index 228a8892ba..45f4e3b348 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyBootstrap.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyBootstrap.scala @@ -1,10 +1,9 @@ package sttp.tapir.server.netty.internal import io.netty.bootstrap.ServerBootstrap -import io.netty.channel.{Channel, ChannelFuture, ChannelInitializer, ChannelOption, EventLoopGroup} +import io.netty.channel.{Channel, ChannelFuture, ChannelHandler, ChannelInitializer, ChannelOption, EventLoopGroup} import io.netty.handler.timeout.ReadTimeoutHandler import sttp.tapir.server.netty.NettyConfig -import ws.ReactiveWebSocketHandler import java.net.{InetSocketAddress, SocketAddress} @@ -14,8 +13,7 @@ object NettyBootstrap { def apply[F[_]]( nettyConfig: NettyConfig, - handler: => NettyServerHandler[F], - wsHandler: => ReactiveWebSocketHandler[F], + handlers: => List[ChannelHandler], eventLoopGroup: EventLoopGroup, overrideSocketAddress: Option[SocketAddress] ): ChannelFuture = { @@ -33,10 +31,9 @@ object NettyBootstrap { case Some(requestTimeout) => nettyConfigBuilder( ch.pipeline().addLast(ReadTimeoutHandlerName, new ReadTimeoutHandler(requestTimeout.toSeconds.toInt)), - handler, - wsHandler + handlers, ) - case None => nettyConfigBuilder(ch.pipeline(), handler, wsHandler) + case None 
=> nettyConfigBuilder(ch.pipeline(), handlers) } connectionCounterOpt.map(counter => { diff --git a/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/NettyZioServer.scala b/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/NettyZioServer.scala index da8c9082d3..093b3bd966 100644 --- a/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/NettyZioServer.scala +++ b/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/NettyZioServer.scala @@ -7,7 +7,7 @@ import io.netty.util.concurrent.DefaultEventExecutor import sttp.capabilities.zio.ZioStreams import sttp.tapir.server.ServerEndpoint import sttp.tapir.server.model.ServerResponse -import sttp.tapir.server.netty.internal.{NettyBootstrap, NettyServerHandler, ReactiveWebSocketHandler} +import sttp.tapir.server.netty.internal.{NettyBootstrap, NettyServerHandler} import sttp.tapir.server.netty.zio.internal.ZioUtil.{nettyChannelFutureToScala, nettyFutureToScala} import sttp.tapir.server.netty.{NettyConfig, NettyResponse, Route} import sttp.tapir.ztapir.{RIOMonadError, ZServerEndpoint} @@ -86,14 +86,15 @@ case class NettyZioServer[R](routes: Vector[RIO[R, Route[RIO[R, *]]]], options: NettyBootstrap[RIO[R, *]]( config, - new NettyServerHandler[RIO[R, *]]( - route, - unsafeRunAsync(runtime), - channelGroup, - isShuttingDown, - config.serverHeader + List( + new NettyServerHandler[RIO[R, *]]( + route, + unsafeRunAsync(runtime), + channelGroup, + isShuttingDown, + config.serverHeader + ) ), - new ReactiveWebSocketHandler[RIO[R, *]](route, channelGroup, unsafeRunAsync(runtime), config.sslContext.isDefined), eventLoopGroup, socketOverride ) From b9db8dd48d57c68433fb0a51034caf20c08e29a2 Mon Sep 17 00:00:00 2001 From: kciesielski Date: Thu, 28 Mar 2024 16:21:59 +0100 Subject: [PATCH 27/31] More review fixes --- .../server/akkahttp/AkkaHttpServerTest.scala | 3 +- .../server/http4s/Http4sServerTest.scala | 3 +- .../http4s/ztapir/ZHttp4sServerTest.scala | 3 +- .../server/netty/cats/NettyCatsServer.scala | 9 +- .../internal/WebSocketPipeProcessor.scala | 29 ++--- .../netty/cats/NettyCatsServerTest.scala | 3 +- .../tapir/server/netty/internal/package.scala | 14 +- .../ws/ReactiveWebSocketHandler.scala | 120 +++++++++++------- .../pekkohttp/PekkoHttpServerTest.scala | 3 +- .../tapir/server/play/PlayServerTest.scala | 3 +- .../tapir/server/play/PlayServerTest.scala | 3 +- .../server/tests/ServerWebSocketTests.scala | 94 +++++++------- .../vertx/cats/CatsVertxServerTest.scala | 3 +- .../tapir/server/vertx/VertxServerTest.scala | 3 +- .../server/vertx/zio/ZioVertxServerTest.scala | 3 +- .../server/ziohttp/ZioHttpServerTest.scala | 3 +- 16 files changed, 163 insertions(+), 136 deletions(-) diff --git a/server/akka-http-server/src/test/scala/sttp/tapir/server/akkahttp/AkkaHttpServerTest.scala b/server/akka-http-server/src/test/scala/sttp/tapir/server/akkahttp/AkkaHttpServerTest.scala index 8df9f48fb5..db589bce84 100644 --- a/server/akka-http-server/src/test/scala/sttp/tapir/server/akkahttp/AkkaHttpServerTest.scala +++ b/server/akka-http-server/src/test/scala/sttp/tapir/server/akkahttp/AkkaHttpServerTest.scala @@ -162,8 +162,7 @@ class AkkaHttpServerTest extends TestSuite with EitherValues { AkkaStreams, autoPing = false, failingPipe = true, - handlePong = false, - rejectNonWsEndpoints = false + handlePong = false ) { override def functionToPipe[A, B](f: A => B): streams.Pipe[A, B] = Flow.fromFunction(f) override def emptyPipe[A, B]: Flow[A, B, Any] = Flow.fromSinkAndSource(Sink.ignore, Source.empty) diff 
--git a/server/http4s-server/src/test/scala/sttp/tapir/server/http4s/Http4sServerTest.scala b/server/http4s-server/src/test/scala/sttp/tapir/server/http4s/Http4sServerTest.scala index 2bcd397e32..7f6028b33a 100644 --- a/server/http4s-server/src/test/scala/sttp/tapir/server/http4s/Http4sServerTest.scala +++ b/server/http4s-server/src/test/scala/sttp/tapir/server/http4s/Http4sServerTest.scala @@ -144,8 +144,7 @@ class Http4sServerTest[R >: Fs2Streams[IO] with WebSockets] extends TestSuite wi Fs2Streams[IO], autoPing = true, failingPipe = true, - handlePong = false, - rejectNonWsEndpoints = false + handlePong = false ) { override def functionToPipe[A, B](f: A => B): streams.Pipe[A, B] = in => in.map(f) override def emptyPipe[A, B]: Pipe[IO, A, B] = _ => fs2.Stream.empty diff --git a/server/http4s-server/zio/src/test/scala/sttp/tapir/server/http4s/ztapir/ZHttp4sServerTest.scala b/server/http4s-server/zio/src/test/scala/sttp/tapir/server/http4s/ztapir/ZHttp4sServerTest.scala index 55011b7f5f..5aafe6caa7 100644 --- a/server/http4s-server/zio/src/test/scala/sttp/tapir/server/http4s/ztapir/ZHttp4sServerTest.scala +++ b/server/http4s-server/zio/src/test/scala/sttp/tapir/server/http4s/ztapir/ZHttp4sServerTest.scala @@ -60,8 +60,7 @@ class ZHttp4sServerTest extends TestSuite with OptionValues { ZioStreams, autoPing = true, failingPipe = false, - handlePong = false, - rejectNonWsEndpoints = false + handlePong = false ) { override def functionToPipe[A, B](f: A => B): streams.Pipe[A, B] = in => in.map(f) override def emptyPipe[A, B]: streams.Pipe[A, B] = _ => ZStream.empty diff --git a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServer.scala b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServer.scala index 0c18fdefa1..38d59d6ebc 100644 --- a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServer.scala +++ b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServer.scala @@ -80,7 +80,14 @@ case class NettyCatsServer[F[_]: Async](routes: Vector[Route[F]], options: Netty NettyBootstrap( config, List( - new ReactiveWebSocketHandler(route, channelGroup, unsafeRunAsync, config.sslContext.isDefined), + new ReactiveWebSocketHandler( + route, + channelGroup, + unsafeRunAsync, + config.sslContext.isDefined, + isShuttingDown, + config.serverHeader + ), new NettyServerHandler(route, unsafeRunAsync, channelGroup, isShuttingDown, config.serverHeader) ), eventLoopGroup, diff --git a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/WebSocketPipeProcessor.scala b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/WebSocketPipeProcessor.scala index d6d08100ab..3147ffc7af 100644 --- a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/WebSocketPipeProcessor.scala +++ b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/WebSocketPipeProcessor.scala @@ -1,8 +1,8 @@ package sttp.tapir.server.netty.cats.internal import cats.Applicative -import cats.effect.kernel.{Async, Sync} import cats.effect.kernel.Resource.ExitCase +import cats.effect.kernel.{Async, Sync} import cats.effect.std.Dispatcher import cats.syntax.all._ import fs2.interop.reactivestreams.{StreamSubscriber, StreamUnicastPublisher} @@ -17,9 +17,8 @@ import sttp.tapir.server.netty.internal.ws.WebSocketFrameConverters._ import sttp.tapir.{DecodeResult, WebSocketBodyOutput} import sttp.ws.WebSocketFrame -import 
scala.concurrent.ExecutionContext.Implicits -import scala.concurrent.Promise -import scala.util.{Failure, Success} +import scala.concurrent.{ExecutionContext, Promise} +import scala.util.Success /** A Reactive Streams Processor[NettyWebSocketFrame, NettyWebSocketFrame] built from a fs2.Pipe[F, REQ, RESP] passed from an WS endpoint. */ @@ -29,19 +28,16 @@ class WebSocketPipeProcessor[F[_]: Async, REQ, RESP]( o: WebSocketBodyOutput[Pipe[F, REQ, RESP], REQ, RESP, ?, Fs2Streams[F]], wsCompletedPromise: ChannelPromise ) extends Processor[NettyWebSocketFrame, NettyWebSocketFrame] { - @volatile private var subscriber: StreamSubscriber[F, NettyWebSocketFrame] = _ + // Not really that unsafe. Subscriber creation doesn't do any IO, only initializes an AtomicReference in an initial state. + private val subscriber: StreamSubscriber[F, NettyWebSocketFrame] = dispatcher.unsafeRunSync( + // If bufferSize > 1, the stream may stale and not emit responses until enough requests are buffered + StreamSubscriber[F, NettyWebSocketFrame](bufferSize = 1) + ) private val publisher: Promise[Publisher[NettyWebSocketFrame]] = Promise[Publisher[NettyWebSocketFrame]]() - @volatile private var subscription: Subscription = _ - private val logger = LoggerFactory.getLogger(getClass.getName) override def onSubscribe(s: Subscription): Unit = { - // Not really that unsafe. Subscriber creation doesn't do any IO, only initializes an AtomicReference in an initial state. - subscriber = dispatcher.unsafeRunSync( - // If bufferSize > 1, the stream may stale and not emit responses until enough requests are buffered - StreamSubscriber[F, NettyWebSocketFrame](bufferSize = 1) - ) - subscription = new NonCancelingSubscription(s) + val subscription = new NonCancelingSubscription(s) val in: Stream[F, NettyWebSocketFrame] = subscriber.sub.stream(Applicative[F].unit) val sttpFrames = in.map { f => val sttpFrame = nettyFrameToFrame(f) @@ -97,10 +93,9 @@ class WebSocketPipeProcessor[F[_]: Async, REQ, RESP]( publisher.future.onComplete { case Success(p) => p.subscribe(s) - case Failure(ex) => - subscriber.sub.onError(ex) - subscription.cancel - }(Implicits.global) + case _ => // Never happens, we call succecss() explicitly + }(ExecutionContext.parasitic) + } private def optionallyConcatenateFrames(s: Stream[F, WebSocketFrame], doConcatenate: Boolean): Stream[F, WebSocketFrame] = diff --git a/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsServerTest.scala b/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsServerTest.scala index 74c67cf028..73053ed362 100644 --- a/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsServerTest.scala +++ b/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsServerTest.scala @@ -46,8 +46,7 @@ class NettyCatsServerTest extends TestSuite with EitherValues { Fs2Streams[IO], autoPing = true, failingPipe = true, - handlePong = true, - rejectNonWsEndpoints = true + handlePong = true ) { override def functionToPipe[A, B](f: A => B): streams.Pipe[A, B] = in => in.map(f) override def emptyPipe[A, B]: fs2.Pipe[IO, A, B] = _ => fs2.Stream.empty diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/package.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/package.scala index 2615d68b8c..e8cbf4d79c 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/package.scala +++ 
b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/package.scala @@ -1,8 +1,9 @@ package sttp.tapir.server.netty import io.netty.channel.{ChannelFuture, ChannelFutureListener} -import io.netty.handler.codec.http.HttpHeaders +import io.netty.handler.codec.http.{HttpHeaderNames, HttpHeaders, HttpMessage} import sttp.model.Header +import sttp.tapir.server.model.ServerResponse import scala.collection.JavaConverters._ @@ -17,6 +18,17 @@ package object internal { val _ = cf.addListener(ChannelFutureListener.CLOSE) } } + + implicit class RichHttpMessage(val m: HttpMessage) { + def setHeadersFrom(response: ServerResponse[_], serverHeader: Option[String]): Unit = { + serverHeader.foreach(m.headers().set(HttpHeaderNames.SERVER, _)) + response.headers + .groupBy(_.name) + .foreach { case (k, v) => + m.headers().set(k, v.map(_.value).asJava) + } + } + } val ServerCodecHandlerName = "serverCodecHandler" val WebSocketControlFrameHandlerName = "wsControlFrameHandler" } diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/ws/ReactiveWebSocketHandler.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/ws/ReactiveWebSocketHandler.scala index e2c9c7bff0..ef1cf87a34 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/ws/ReactiveWebSocketHandler.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/ws/ReactiveWebSocketHandler.scala @@ -19,6 +19,10 @@ import sttp.tapir.server.netty.{NettyResponse, NettyResponseContent, NettyServer import scala.concurrent.{ExecutionContext, Future} import scala.util.control.NonFatal import scala.util.{Failure, Success} +import sttp.tapir.server.netty.NettyResponseContent.ByteBufNettyResponseContent +import io.netty.channel.ChannelPromise +import io.netty.buffer.ByteBuf +import java.util.concurrent.atomic.AtomicBoolean /** Handles a WS handshake and initiates the communication by calling Tapir interpreter to get a Pipe, then sends that Pipe to the rest of * the processing pipeline and removes itself from the pipeline. 
@@ -27,7 +31,9 @@ class ReactiveWebSocketHandler[F[_]]( route: Route[F], channelGroup: ChannelGroup, unsafeRunAsync: (() => F[ServerResponse[NettyResponse]]) => (Future[ServerResponse[NettyResponse]], () => Future[Unit]), - isSsl: Boolean + isSsl: Boolean, + isShuttingDown: AtomicBoolean, + serverHeader: Option[String] )(implicit m: MonadError[F]) extends ChannelInboundHandlerAdapter { @@ -73,12 +79,10 @@ class ReactiveWebSocketHandler[F[_]]( val _ = ctx.writeAndFlush(res).close() } - def rejectHandshakeForRegularEndpoint(content: Option[NettyResponseContent]): Unit = { - val message = "Unexpected WebSocket handshake on a regular HTTP endpoint" - val res = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.BAD_REQUEST, Unpooled.wrappedBuffer(message.getBytes)) - res.headers().set(HttpHeaderNames.CONTENT_LENGTH, message.length()) + def replyWith503(): Unit = { + val res = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.SERVICE_UNAVAILABLE) + res.headers().set(HttpHeaderNames.CONTENT_LENGTH, 0) res.headers.set(HttpHeaderNames.CONNECTION, HttpHeaderValues.CLOSE) - content.foreach(_.channelPromise.setFailure(new IllegalStateException("Unexpected response content"))) val _ = ctx.writeAndFlush(res).close() } @@ -88,54 +92,74 @@ class ReactiveWebSocketHandler[F[_]]( val _ = ctx.writeAndFlush(res).close() } + def replyWithRouteResponse( + req: HttpRequest, + serverResponse: ServerResponse[NettyResponse], + channelPromise: ChannelPromise, + byteBuf: ByteBuf + ): Unit = { + val res = new DefaultFullHttpResponse(req.protocolVersion(), HttpResponseStatus.valueOf(serverResponse.code.code), byteBuf) + res.setHeadersFrom(serverResponse, serverHeader) + res.headers().set(HttpHeaderNames.CONTENT_LENGTH, byteBuf.readableBytes()) + val _ = ctx.writeAndFlush(res, channelPromise) + } + msg match { case req: FullHttpRequest if isWsHandshake(req) => - ctx.pipeline().remove(this) - ctx.pipeline().remove(classOf[ReadTimeoutHandler]) - ReferenceCountUtil.release(msg) - val (runningFuture, _) = unsafeRunAsync { () => - route(NettyServerRequest(req.retain())) - .map { - case Some(response) => response - case None => ServerResponse.notFound - } - } - - runningFuture.onComplete { - case Success(serverResponse) if serverResponse == ServerResponse.notFound => - replyNotFound() - case Success(serverResponse) => - try { - serverResponse.body match { - case Some(function) => { - val content = function(ctx) - content match { - case r: ReactiveWebSocketProcessorNettyResponseContent => - initWsPipeline(ctx, r, req) - case otherContent => - rejectHandshakeForRegularEndpoint(Some(otherContent)) + if (isShuttingDown.get()) { + logger.info("Rejecting WS handshake request because the server is shutting down.") + replyWith503() + } else { + ReferenceCountUtil.release(msg) + val (runningFuture, _) = unsafeRunAsync { () => + route(NettyServerRequest(req.retain())) + .map { + case Some(response) => response + case None => ServerResponse.notFound + } + } + + runningFuture.onComplete { + case Success(serverResponse) if serverResponse == ServerResponse.notFound => + replyNotFound() + case Success(serverResponse) => + try { + serverResponse.body match { + case Some(function) => { + val content = function(ctx) + content match { + case r: ReactiveWebSocketProcessorNettyResponseContent => + initWsPipeline(ctx, r, req) + case ByteBufNettyResponseContent(channelPromise, byteBuf) => + // Handshake didn't return a Pipe, but a regular response. Returning it back. 
+ replyWithRouteResponse(req, serverResponse, channelPromise, byteBuf) + case otherContent => + // An unsupported type of regular response, returning 500 + replyWithError500( + new IllegalArgumentException(s"Unsupported response type for a WS endpoint: ${otherContent.getClass.getName}") + ) + } } + case None => + // Handshake didn't return a Pipe, but a regular response with empty body. Returning it back. + replyWithRouteResponse(req, serverResponse, ctx.newPromise(), Unpooled.EMPTY_BUFFER) } - case None => - rejectHandshakeForRegularEndpoint(content = None) + } catch { + case NonFatal(ex) => + replyWithError500(ex) + } finally { + val _ = req.release() } - } catch { - case NonFatal(ex) => + case Failure(ex) => + try { replyWithError500(ex) - } finally { - val _ = req.release() - } - case Failure(ex) => - try { - replyWithError500(ex) - } finally { - val _ = req.release() - } - }(eventLoopContext) - + } finally { + val _ = req.release() + } + }(eventLoopContext) + } case other => - // not a WS handshake, from now on process messages as normal HTTP requests in this channel - ctx.pipeline.remove(this) + // not a WS handshake val _ = ctx.fireChannelRead(other) } } @@ -145,6 +169,8 @@ class ReactiveWebSocketHandler[F[_]]( r: ReactiveWebSocketProcessorNettyResponseContent, handshakeReq: FullHttpRequest ) = { + ctx.pipeline().remove(this) + ctx.pipeline().remove(classOf[ReadTimeoutHandler]) ctx .pipeline() .addAfter( diff --git a/server/pekko-http-server/src/test/scala/sttp/tapir/server/pekkohttp/PekkoHttpServerTest.scala b/server/pekko-http-server/src/test/scala/sttp/tapir/server/pekkohttp/PekkoHttpServerTest.scala index 292bad134b..edf58cb4db 100644 --- a/server/pekko-http-server/src/test/scala/sttp/tapir/server/pekkohttp/PekkoHttpServerTest.scala +++ b/server/pekko-http-server/src/test/scala/sttp/tapir/server/pekkohttp/PekkoHttpServerTest.scala @@ -110,8 +110,7 @@ class PekkoHttpServerTest extends TestSuite with EitherValues { PekkoStreams, autoPing = false, failingPipe = true, - handlePong = false, - rejectNonWsEndpoints = false + handlePong = false ) { override def functionToPipe[A, B](f: A => B): streams.Pipe[A, B] = Flow.fromFunction(f) override def emptyPipe[A, B]: Flow[A, B, Any] = Flow.fromSinkAndSource(Sink.ignore, Source.empty) diff --git a/server/play-server/src/test/scala/sttp/tapir/server/play/PlayServerTest.scala b/server/play-server/src/test/scala/sttp/tapir/server/play/PlayServerTest.scala index 2d6e67fb70..7ef2dc432d 100644 --- a/server/play-server/src/test/scala/sttp/tapir/server/play/PlayServerTest.scala +++ b/server/play-server/src/test/scala/sttp/tapir/server/play/PlayServerTest.scala @@ -127,8 +127,7 @@ class PlayServerTest extends TestSuite { PekkoStreams, autoPing = false, failingPipe = true, - handlePong = false, - rejectNonWsEndpoints = false + handlePong = false ) { override def functionToPipe[A, B](f: A => B): streams.Pipe[A, B] = Flow.fromFunction(f) override def emptyPipe[A, B]: Flow[A, B, Any] = Flow.fromSinkAndSource(Sink.ignore, Source.empty) diff --git a/server/play29-server/src/test/scala/sttp/tapir/server/play/PlayServerTest.scala b/server/play29-server/src/test/scala/sttp/tapir/server/play/PlayServerTest.scala index 663ee13e3a..621feb18d3 100644 --- a/server/play29-server/src/test/scala/sttp/tapir/server/play/PlayServerTest.scala +++ b/server/play29-server/src/test/scala/sttp/tapir/server/play/PlayServerTest.scala @@ -122,8 +122,7 @@ class PlayServerTest extends TestSuite { AkkaStreams, autoPing = false, failingPipe = true, - handlePong = false, - 
rejectNonWsEndpoints = true + handlePong = false ) { override def functionToPipe[A, B](f: A => B): streams.Pipe[A, B] = Flow.fromFunction(f) override def emptyPipe[A, B]: Flow[A, B, Any] = Flow.fromSinkAndSource(Sink.ignore, Source.empty) diff --git a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerWebSocketTests.scala b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerWebSocketTests.scala index b2b591101a..7ec6a4a2bc 100644 --- a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerWebSocketTests.scala +++ b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerWebSocketTests.scala @@ -27,8 +27,7 @@ abstract class ServerWebSocketTests[F[_], S <: Streams[S], OPTIONS, ROUTE]( val streams: S, autoPing: Boolean, failingPipe: Boolean, - handlePong: Boolean, - rejectNonWsEndpoints: Boolean + handlePong: Boolean )(implicit m: MonadError[F] ) extends EitherValues { @@ -210,20 +209,55 @@ abstract class ServerWebSocketTests[F[_], S <: Streams[S], OPTIONS, ROUTE]( }, testServer( endpoint - .in(isWebSocket) + .in(query[String]("securityToken")) .errorOut(stringBody) .out(stringWs), - "non web-socket request" - )(isWS => if (isWS) pureResult(stringEcho.asRight) else pureResult("Not a WS!".asLeft)) { (backend, baseUri) => - basicRequest - .response(asString) - .get(baseUri.scheme("http")) - .send(backend) - .map { r => - r.body shouldBe Left("Not a WS!") + "switch to WS after a normal HTTP request" + )(token => if (token == "correctToken") pureResult(stringEcho.asRight) else pureResult("Incorrect token!".asLeft)) { + (backend, baseUri) => + for { + response1 <- basicRequest + .response(asString) + .get(uri"$baseUri?securityToken=wrong".scheme("http")) + .send(backend) + response2 <- basicRequest + .response(asWebSocket { (ws: WebSocket[IO]) => + ws.sendText("testOk") >> ws.receiveText() + }) + .get(uri"$baseUri?securityToken=correctToken".scheme("ws")) + .send(backend) + } yield { + response1.body shouldBe Left("Incorrect token!") + response2.body shouldBe Right("echo: testOk") + } + }, + testServer( + endpoint + .in(query[String]("securityToken")) + .errorOut(stringBody) + .out(stringWs), + "reject WS handshake, then accept a corrected one" + )(token => if (token == "correctToken") pureResult(stringEcho.asRight) else pureResult("Incorrect token!".asLeft)) { + (backend, baseUri) => + for { + response1 <- basicRequest + .response(asWebSocket { (ws: WebSocket[IO]) => + ws.sendText("testWrong") >> ws.receiveText() + }) + .get(uri"$baseUri?securityToken=wrong".scheme("ws")) + .send(backend) + response2 <- basicRequest + .response(asWebSocket { (ws: WebSocket[IO]) => + ws.sendText("testOk") >> ws.receiveText() + }) + .get(uri"$baseUri?securityToken=correctToken".scheme("ws")) + .send(backend) + } yield { + response1.code shouldBe StatusCode.BadRequest + response2.body shouldBe Right("echo: testOk") } } - ) ++ autoPingTests ++ failingPipeTests ++ handlePongTests ++ rejectNonWsEndpointsTests + ) ++ autoPingTests ++ failingPipeTests ++ handlePongTests val autoPingTests = if (autoPing) @@ -335,40 +369,4 @@ abstract class ServerWebSocketTests[F[_], S <: Streams[S], OPTIONS, ROUTE]( } ) else List.empty - - val rejectNonWsEndpointsTests = - if (rejectNonWsEndpoints) - List( - testServer( - endpoint.out(stringBody), - "WS handshake to a non-WS endpoint" - )((_: Unit) => pureResult("hello".asRight[Unit])) { (backend, baseUri) => - basicRequest - .response(asWebSocket { (ws: WebSocket[IO]) => - for { - _ <- ws.sendText("test") - m1 <- ws.receiveText() - } yield List(m1) - }) - 
.get(baseUri.scheme("ws")) - .send(backend) - .map(_.code shouldBe StatusCode.BadRequest) - }, - testServer( - endpoint.out(emptyOutput), - "WS handshake to a non-WS endpoint with empty output" // to make sure this won't be treated as 404 - )((_: Unit) => pureResult(().asRight[Unit])) { (backend, baseUri) => - basicRequest - .response(asWebSocket { (ws: WebSocket[IO]) => - for { - _ <- ws.sendText("test") - m1 <- ws.receiveText() - } yield List(m1) - }) - .get(baseUri.scheme("ws")) - .send(backend) - .map(_.code shouldBe StatusCode.BadRequest) - } - ) - else List.empty } diff --git a/server/vertx-server/cats/src/test/scala/sttp/tapir/server/vertx/cats/CatsVertxServerTest.scala b/server/vertx-server/cats/src/test/scala/sttp/tapir/server/vertx/cats/CatsVertxServerTest.scala index 70191e4241..9dfd570fa7 100644 --- a/server/vertx-server/cats/src/test/scala/sttp/tapir/server/vertx/cats/CatsVertxServerTest.scala +++ b/server/vertx-server/cats/src/test/scala/sttp/tapir/server/vertx/cats/CatsVertxServerTest.scala @@ -42,8 +42,7 @@ class CatsVertxServerTest extends TestSuite { Fs2Streams.apply[IO], autoPing = false, failingPipe = true, - handlePong = true, - rejectNonWsEndpoints = false + handlePong = true ) { override def functionToPipe[A, B](f: A => B): streams.Pipe[A, B] = in => in.map(f) override def emptyPipe[A, B]: streams.Pipe[A, B] = _ => Stream.empty diff --git a/server/vertx-server/src/test/scala/sttp/tapir/server/vertx/VertxServerTest.scala b/server/vertx-server/src/test/scala/sttp/tapir/server/vertx/VertxServerTest.scala index cf5c6157aa..bb6294924a 100644 --- a/server/vertx-server/src/test/scala/sttp/tapir/server/vertx/VertxServerTest.scala +++ b/server/vertx-server/src/test/scala/sttp/tapir/server/vertx/VertxServerTest.scala @@ -58,8 +58,7 @@ class VertxServerTest extends TestSuite { VertxStreams, autoPing = false, failingPipe = false, - handlePong = true, - rejectNonWsEndpoints = false + handlePong = true ) { override def functionToPipe[A, B](f: A => B): VertxStreams.Pipe[A, B] = in => new ReadStreamMapping(in, f) override def emptyPipe[A, B]: VertxStreams.Pipe[A, B] = _ => new EmptyReadStream() diff --git a/server/vertx-server/zio/src/test/scala/sttp/tapir/server/vertx/zio/ZioVertxServerTest.scala b/server/vertx-server/zio/src/test/scala/sttp/tapir/server/vertx/zio/ZioVertxServerTest.scala index 3f112f5ba7..a8d79dcb5b 100644 --- a/server/vertx-server/zio/src/test/scala/sttp/tapir/server/vertx/zio/ZioVertxServerTest.scala +++ b/server/vertx-server/zio/src/test/scala/sttp/tapir/server/vertx/zio/ZioVertxServerTest.scala @@ -48,8 +48,7 @@ class ZioVertxServerTest extends TestSuite with OptionValues { ZioStreams, autoPing = true, failingPipe = true, - handlePong = true, - rejectNonWsEndpoints = false + handlePong = true ) { override def functionToPipe[A, B](f: A => B): streams.Pipe[A, B] = in => in.map(f) override def emptyPipe[A, B]: streams.Pipe[A, B] = _ => ZStream.empty diff --git a/server/zio-http-server/src/test/scala/sttp/tapir/server/ziohttp/ZioHttpServerTest.scala b/server/zio-http-server/src/test/scala/sttp/tapir/server/ziohttp/ZioHttpServerTest.scala index e0b71f21b1..417e16a5af 100644 --- a/server/zio-http-server/src/test/scala/sttp/tapir/server/ziohttp/ZioHttpServerTest.scala +++ b/server/zio-http-server/src/test/scala/sttp/tapir/server/ziohttp/ZioHttpServerTest.scala @@ -271,8 +271,7 @@ class ZioHttpServerTest extends TestSuite { ZioStreams, autoPing = true, failingPipe = false, - handlePong = false, - rejectNonWsEndpoints = false + handlePong = false ) { override def 
functionToPipe[A, B](f: A => B): ZioStreams.Pipe[A, B] = in => in.map(f) override def emptyPipe[A, B]: ZioStreams.Pipe[A, B] = _ => ZStream.empty From ba7fb913fc538fd0156deee11d23f7d463a7d721 Mon Sep 17 00:00:00 2001 From: kciesielski Date: Thu, 28 Mar 2024 16:46:43 +0100 Subject: [PATCH 28/31] Use parasitic EC --- .../netty/cats/internal/ExecutionContexts.scala | 12 ++++++++++++ .../netty/cats/internal/ExecutionContexts.scala | 7 +++++++ .../netty/cats/internal/WebSocketPipeProcessor.scala | 4 ++-- 3 files changed, 21 insertions(+), 2 deletions(-) create mode 100644 server/netty-server/cats/src/main/scala-2.12/sttp/tapir/server/netty/cats/internal/ExecutionContexts.scala create mode 100644 server/netty-server/cats/src/main/scala-3-2.13+/sttp/tapir/server/netty/cats/internal/ExecutionContexts.scala diff --git a/server/netty-server/cats/src/main/scala-2.12/sttp/tapir/server/netty/cats/internal/ExecutionContexts.scala b/server/netty-server/cats/src/main/scala-2.12/sttp/tapir/server/netty/cats/internal/ExecutionContexts.scala new file mode 100644 index 0000000000..f88a74ed12 --- /dev/null +++ b/server/netty-server/cats/src/main/scala-2.12/sttp/tapir/server/netty/cats/internal/ExecutionContexts.scala @@ -0,0 +1,12 @@ +package sttp.tapir.server.netty.cats.internal + +import scala.concurrent.ExecutionContext + +object ExecutionContexts { + val sameThread: ExecutionContext = new ExecutionContext { + override def execute(runnable: Runnable): Unit = runnable.run() + + override def reportFailure(cause: Throwable): Unit = + ExecutionContext.defaultReporter(cause) + } +} diff --git a/server/netty-server/cats/src/main/scala-3-2.13+/sttp/tapir/server/netty/cats/internal/ExecutionContexts.scala b/server/netty-server/cats/src/main/scala-3-2.13+/sttp/tapir/server/netty/cats/internal/ExecutionContexts.scala new file mode 100644 index 0000000000..dacc9e5473 --- /dev/null +++ b/server/netty-server/cats/src/main/scala-3-2.13+/sttp/tapir/server/netty/cats/internal/ExecutionContexts.scala @@ -0,0 +1,7 @@ +package sttp.tapir.server.netty.cats.internal + +import scala.concurrent.ExecutionContext + +object ExecutionContexts { + val sameThread: ExecutionContext = ExecutionContext.parasitic +} diff --git a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/WebSocketPipeProcessor.scala b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/WebSocketPipeProcessor.scala index 3147ffc7af..26e50651de 100644 --- a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/WebSocketPipeProcessor.scala +++ b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/WebSocketPipeProcessor.scala @@ -17,7 +17,7 @@ import sttp.tapir.server.netty.internal.ws.WebSocketFrameConverters._ import sttp.tapir.{DecodeResult, WebSocketBodyOutput} import sttp.ws.WebSocketFrame -import scala.concurrent.{ExecutionContext, Promise} +import scala.concurrent.Promise import scala.util.Success /** A Reactive Streams Processor[NettyWebSocketFrame, NettyWebSocketFrame] built from a fs2.Pipe[F, REQ, RESP] passed from an WS endpoint. 
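The hunk below swaps ExecutionContext.parasitic for the ExecutionContexts.sameThread alias defined just above, since parasitic is not available on Scala 2.12, which this module still cross-builds for (hence the two version-specific source files added in this commit). A calling-thread executor fits here because the callback only forwards the already-created publisher to the subscriber (the promise is completed via success(...), as the adjacent comment notes), so there is no work worth shipping to a pool. A minimal, self-contained sketch of that behaviour, using only the standard library; SameThreadDemo and its values are illustrative names, not part of the patch.

import scala.concurrent.{ExecutionContext, Promise}

object SameThreadDemo extends App {
  // A calling-thread executor, in the spirit of ExecutionContext.parasitic on 2.13+/3:
  // callbacks run inline instead of being scheduled on a pool.
  val sameThread: ExecutionContext = new ExecutionContext {
    def execute(runnable: Runnable): Unit = runnable.run()
    def reportFailure(cause: Throwable): Unit = ExecutionContext.defaultReporter(cause)
  }

  val p = Promise[String]()
  p.success("publisher ready") // already completed, mirroring how the processor fulfils its promise
  // With a same-thread executor the callback fires immediately on the calling thread,
  // without a hop to a thread pool just to deliver an already-available value.
  p.future.onComplete(r => println(s"${Thread.currentThread().getName}: $r"))(sameThread)
}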
@@ -94,7 +94,7 @@ class WebSocketPipeProcessor[F[_]: Async, REQ, RESP]( case Success(p) => p.subscribe(s) case _ => // Never happens, we call succecss() explicitly - }(ExecutionContext.parasitic) + }(ExecutionContexts.sameThread) } From ae656093e4b81de40b8e31c10fab40f37b0ca1f9 Mon Sep 17 00:00:00 2001 From: kciesielski Date: Fri, 29 Mar 2024 13:17:56 +0100 Subject: [PATCH 29/31] Improvements after code review --- .../tapir/server/netty/internal/NettyServerHandler.scala | 6 +++--- .../scala/sttp/tapir/server/netty/internal/package.scala | 8 ++++---- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala index 1ec09f7d27..33d069f6b8 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala @@ -236,7 +236,7 @@ class NettyServerHandler[F[_]]( } ) - private implicit class RichServerNettyResponse(val r: ServerResponse[NettyResponse]) { + private implicit class RichServerNettyResponse(r: ServerResponse[NettyResponse]) { def handle( ctx: ChannelHandlerContext, byteBufHandler: (ChannelPromise, ByteBuf) => Unit, @@ -263,7 +263,7 @@ class NettyServerHandler[F[_]]( } } - private implicit class RichHttpMessage(val m: HttpMessage) { + private implicit class RichHttpMessage(m: HttpMessage) { def setHeadersFrom(response: ServerResponse[_]): Unit = { serverHeader.foreach(m.headers().set(HttpHeaderNames.SERVER, _)) response.headers @@ -289,7 +289,7 @@ class NettyServerHandler[F[_]]( } } - private implicit class RichChannelFuture(val cf: ChannelFuture) { + private implicit class RichChannelFuture(cf: ChannelFuture) { def closeIfNeeded(request: HttpRequest): Unit = { if (!HttpUtil.isKeepAlive(request) || isShuttingDown.get()) { cf.addListener(ChannelFutureListener.CLOSE) diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/package.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/package.scala index e8cbf4d79c..4a4659d14e 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/package.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/package.scala @@ -8,7 +8,7 @@ import sttp.tapir.server.model.ServerResponse import scala.collection.JavaConverters._ package object internal { - implicit class RichNettyHttpHeaders(underlying: HttpHeaders) { + implicit class RichNettyHttpHeaders(private val underlying: HttpHeaders) extends AnyVal { def toHeaderSeq: List[Header] = underlying.asScala.map(e => Header(e.getKey, e.getValue)).toList } @@ -19,7 +19,7 @@ package object internal { } } - implicit class RichHttpMessage(val m: HttpMessage) { + implicit class RichHttpMessage(private val m: HttpMessage) extends AnyVal { def setHeadersFrom(response: ServerResponse[_], serverHeader: Option[String]): Unit = { serverHeader.foreach(m.headers().set(HttpHeaderNames.SERVER, _)) response.headers @@ -29,6 +29,6 @@ package object internal { } } } - val ServerCodecHandlerName = "serverCodecHandler" - val WebSocketControlFrameHandlerName = "wsControlFrameHandler" + final val ServerCodecHandlerName = "serverCodecHandler" + final val WebSocketControlFrameHandlerName = "wsControlFrameHandler" } From d2a75d0d5f295764473920ec3c573820d383d72a Mon Sep 17 00:00:00 2001 From: kciesielski Date: Fri, 29 Mar 
2024 15:21:45 +0100 Subject: [PATCH 30/31] Handle handshake in NettyServerHandler --- .../server/netty/cats/NettyCatsServer.scala | 13 +- .../server/netty/loom/NettyIdServer.scala | 15 +- .../sttp/tapir/server/netty/NettyConfig.scala | 10 +- .../server/netty/NettyFutureServer.scala | 3 +- .../netty/internal/NettyBootstrap.scala | 6 +- .../netty/internal/NettyServerHandler.scala | 80 ++++++- .../ws/ReactiveWebSocketHandler.scala | 213 ------------------ .../server/netty/zio/NettyZioServer.scala | 15 +- 8 files changed, 94 insertions(+), 261 deletions(-) delete mode 100644 server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/ws/ReactiveWebSocketHandler.scala diff --git a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServer.scala b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServer.scala index 38d59d6ebc..8ebe94f74c 100644 --- a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServer.scala +++ b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServer.scala @@ -16,7 +16,6 @@ import sttp.tapir.server.ServerEndpoint import sttp.tapir.server.model.ServerResponse import sttp.tapir.server.netty.cats.internal.CatsUtil.{nettyChannelFutureToScala, nettyFutureToScala} import sttp.tapir.server.netty.internal.{NettyBootstrap, NettyServerHandler} -import sttp.tapir.server.netty.internal.ws.ReactiveWebSocketHandler import sttp.tapir.server.netty.{NettyConfig, NettyResponse, Route} import java.net.{InetSocketAddress, SocketAddress} @@ -79,17 +78,7 @@ case class NettyCatsServer[F[_]: Async](routes: Vector[Route[F]], options: Netty val channelFuture = NettyBootstrap( config, - List( - new ReactiveWebSocketHandler( - route, - channelGroup, - unsafeRunAsync, - config.sslContext.isDefined, - isShuttingDown, - config.serverHeader - ), - new NettyServerHandler(route, unsafeRunAsync, channelGroup, isShuttingDown, config.serverHeader) - ), + new NettyServerHandler(route, unsafeRunAsync, channelGroup, isShuttingDown, config.serverHeader, config.isSsl), eventLoopGroup, socketOverride ) diff --git a/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/NettyIdServer.scala b/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/NettyIdServer.scala index 53eb3b4d56..85fc5bdd26 100644 --- a/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/NettyIdServer.scala +++ b/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/NettyIdServer.scala @@ -93,14 +93,13 @@ case class NettyIdServer(routes: Vector[IdRoute], options: NettyIdServerOptions, val channelIdFuture = NettyBootstrap( config, - List( - new NettyServerHandler( - route, - unsafeRunF, - channelGroup, - isShuttingDown, - config.serverHeader - ) + new NettyServerHandler( + route, + unsafeRunF, + channelGroup, + isShuttingDown, + config.serverHeader, + config.isSsl ), eventLoopGroup, socketOverride diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyConfig.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyConfig.scala index 05f7cd165e..e371851211 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyConfig.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyConfig.scala @@ -62,7 +62,7 @@ case class NettyConfig( sslContext: Option[SslContext], eventLoopConfig: EventLoopConfig, socketConfig: NettySocketConfig, - initPipeline: NettyConfig => (ChannelPipeline, List[ChannelHandler]) => 
Unit, + initPipeline: NettyConfig => (ChannelPipeline, ChannelHandler) => Unit, gracefulShutdownTimeout: Option[FiniteDuration], serverHeader: Option[String] ) { @@ -97,12 +97,14 @@ case class NettyConfig( def eventLoopConfig(elc: EventLoopConfig): NettyConfig = copy(eventLoopConfig = elc) def eventLoopGroup(elg: EventLoopGroup): NettyConfig = copy(eventLoopConfig = EventLoopConfig.useExisting(elg)) - def initPipeline(f: NettyConfig => (ChannelPipeline, List[ChannelHandler]) => Unit): NettyConfig = copy(initPipeline = f) + def initPipeline(f: NettyConfig => (ChannelPipeline, ChannelHandler) => Unit): NettyConfig = copy(initPipeline = f) def withGracefulShutdownTimeout(t: FiniteDuration) = copy(gracefulShutdownTimeout = Some(t)) def noGracefulShutdown = copy(gracefulShutdownTimeout = None) def serverHeader(h: String): NettyConfig = copy(serverHeader = Some(h)) + + def isSsl: Boolean = sslContext.isDefined } object NettyConfig { @@ -127,11 +129,11 @@ object NettyConfig { serverHeader = Some(s"tapir/${buildinfo.BuildInfo.version}") ) - def defaultInitPipeline(cfg: NettyConfig)(pipeline: ChannelPipeline, handlers: List[ChannelHandler]): Unit = { + def defaultInitPipeline(cfg: NettyConfig)(pipeline: ChannelPipeline, handler: ChannelHandler): Unit = { cfg.sslContext.foreach(s => pipeline.addLast(s.newHandler(pipeline.channel().alloc()))) pipeline.addLast(ServerCodecHandlerName, new HttpServerCodec()) pipeline.addLast(StreamsHandlerName, new HttpStreamsServerHandler()) - handlers.foreach(pipeline.addLast(_)) + pipeline.addLast(handler) if (cfg.addLoggingHandler) pipeline.addLast(new LoggingHandler()) () } diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyFutureServer.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyFutureServer.scala index 0a5ad68100..0d4285f4e5 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyFutureServer.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyFutureServer.scala @@ -9,7 +9,6 @@ import sttp.tapir.server.ServerEndpoint import sttp.tapir.server.model.ServerResponse import sttp.tapir.server.netty.internal.FutureUtil._ import sttp.tapir.server.netty.internal.{NettyBootstrap, NettyServerHandler} -import sttp.tapir.server.netty.internal.ws.ReactiveWebSocketHandler import java.net.{InetSocketAddress, SocketAddress} import java.nio.file.{Path, Paths} @@ -72,7 +71,7 @@ case class NettyFutureServer(routes: Vector[FutureRoute], options: NettyFutureSe val channelFuture = NettyBootstrap( config, - List(new NettyServerHandler(route, unsafeRunAsync, channelGroup, isShuttingDown, config.serverHeader)), + new NettyServerHandler(route, unsafeRunAsync, channelGroup, isShuttingDown, config.serverHeader, config.isSsl), eventLoopGroup, socketOverride ) diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyBootstrap.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyBootstrap.scala index 45f4e3b348..07edb5ace2 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyBootstrap.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyBootstrap.scala @@ -13,7 +13,7 @@ object NettyBootstrap { def apply[F[_]]( nettyConfig: NettyConfig, - handlers: => List[ChannelHandler], + handler: => NettyServerHandler[F], eventLoopGroup: EventLoopGroup, overrideSocketAddress: Option[SocketAddress] ): ChannelFuture = { @@ -31,9 +31,9 @@ object NettyBootstrap { case Some(requestTimeout) => 
nettyConfigBuilder( ch.pipeline().addLast(ReadTimeoutHandlerName, new ReadTimeoutHandler(requestTimeout.toSeconds.toInt)), - handlers, + handler, ) - case None => nettyConfigBuilder(ch.pipeline(), handlers) + case None => nettyConfigBuilder(ch.pipeline(), handler) } connectionCounterOpt.map(counter => { diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala index 33d069f6b8..8ce0004b9a 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala @@ -4,8 +4,10 @@ import io.netty.buffer.{ByteBuf, Unpooled} import io.netty.channel._ import io.netty.channel.group.ChannelGroup import io.netty.handler.codec.http._ +import io.netty.handler.codec.http.websocketx.WebSocketServerHandshakerFactory import io.netty.handler.stream.{ChunkedFile, ChunkedStream} -import org.playframework.netty.http.{DefaultStreamedHttpResponse, StreamedHttpRequest} +import io.netty.handler.timeout.ReadTimeoutHandler +import org.playframework.netty.http.{DefaultStreamedHttpResponse, DefaultWebSocketHttpResponse, StreamedHttpRequest} import org.reactivestreams.Publisher import org.slf4j.LoggerFactory import sttp.monad.MonadError @@ -18,6 +20,7 @@ import sttp.tapir.server.netty.NettyResponseContent.{ ReactivePublisherNettyResponseContent, ReactiveWebSocketProcessorNettyResponseContent } +import sttp.tapir.server.netty.internal.ws.{NettyControlFrameHandler, WebSocketAutoPingHandler} import sttp.tapir.server.netty.{NettyResponse, NettyServerRequest, Route} import java.util.concurrent.atomic.AtomicBoolean @@ -37,7 +40,8 @@ class NettyServerHandler[F[_]]( unsafeRunAsync: (() => F[ServerResponse[NettyResponse]]) => (Future[ServerResponse[NettyResponse]], () => Future[Unit]), channelGroup: ChannelGroup, isShuttingDown: AtomicBoolean, - serverHeader: Option[String] + serverHeader: Option[String], + isSsl: Boolean = false )(implicit me: MonadError[F] ) extends SimpleChannelInboundHandler[HttpRequest] { @@ -63,6 +67,7 @@ class NettyServerHandler[F[_]]( private[this] val pendingResponses = MutableQueue.empty[() => Future[Unit]] private val logger = LoggerFactory.getLogger(getClass.getName) + private final val WebSocketAutoPingHandlerName = "wsAutoPingHandler" override def handlerAdded(ctx: ChannelHandlerContext): Unit = if (ctx.channel.isActive) { @@ -213,13 +218,16 @@ class NettyServerHandler[F[_]]( ctx.writeAndFlush(res, channelPromise).closeIfNeeded(req) }, - wsHandler = (channelPromise) => { - logger.error( - "Unexpected WebSocket processor response received in NettyServerHandler, it should be handled only in the ReactiveWebSocketHandler" - ) - val res = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.INTERNAL_SERVER_ERROR) - res.headers.set(HttpHeaderNames.CONNECTION, HttpHeaderValues.CLOSE) - val _ = ctx.writeAndFlush(res) + wsHandler = (responseContent) => { + if (isWsHandshake(req)) + initWsPipeline(ctx, responseContent, req) + else { + val buf = Unpooled.wrappedBuffer("Incorrect Web Socket handshake".getBytes) + val res = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.BAD_REQUEST, buf) + res.headers().set(HttpHeaderNames.CONTENT_LENGTH, buf.readableBytes()) + res.handleCloseAndKeepAliveHeaders(req) + ctx.writeAndFlush(res).closeIfNeeded(req) + } }, noBodyHandler = () => { val res = new 
DefaultFullHttpResponse( @@ -236,6 +244,56 @@ class NettyServerHandler[F[_]]( } ) + private def initWsPipeline( + ctx: ChannelHandlerContext, + r: ReactiveWebSocketProcessorNettyResponseContent, + handshakeReq: HttpRequest + ) = { + ctx.pipeline().remove(this) + ctx.pipeline().remove(classOf[ReadTimeoutHandler]) + ctx + .pipeline() + .addAfter( + ServerCodecHandlerName, + WebSocketControlFrameHandlerName, + new NettyControlFrameHandler( + ignorePong = r.ignorePong, + autoPongOnPing = r.autoPongOnPing, + decodeCloseRequests = r.decodeCloseRequests + ) + ) + r.autoPing.foreach { case (interval, pingMsg) => + ctx + .pipeline() + .addAfter( + WebSocketControlFrameHandlerName, + WebSocketAutoPingHandlerName, + new WebSocketAutoPingHandler(interval, pingMsg) + ) + } + // Manually completing the promise, for some reason it won't be completed in writeAndFlush. We need its completion for NettyBodyListener to call back properly + r.channelPromise.setSuccess() + val _ = ctx.writeAndFlush( + // Push a special message down the pipeline, it will be handled by HttpStreamsServerHandler + // and from now on that handler will take control of the flow (our NettyServerHandler will not receive messages) + new DefaultWebSocketHttpResponse( + handshakeReq.protocolVersion(), + HttpResponseStatus.valueOf(200), + r.processor, // the Processor (Pipe) created by Tapir interpreter will be used by HttpStreamsServerHandler + new WebSocketServerHandshakerFactory(wsUrl(handshakeReq), null, false) + ) + ) + } + + private def isWsHandshake(req: HttpRequest): Boolean = + "Websocket".equalsIgnoreCase(req.headers().get(HttpHeaderNames.UPGRADE)) && + "Upgrade".equalsIgnoreCase(req.headers().get(HttpHeaderNames.CONNECTION)) + + // Only ancient WS protocol versions will use this in the response header. 
+ private def wsUrl(req: HttpRequest): String = { + val scheme = if (isSsl) "wss" else "ws" + s"$scheme://${req.headers().get(HttpHeaderNames.HOST)}${req.uri()}" + } private implicit class RichServerNettyResponse(r: ServerResponse[NettyResponse]) { def handle( ctx: ChannelHandlerContext, @@ -243,7 +301,7 @@ class NettyServerHandler[F[_]]( chunkedStreamHandler: (ChannelPromise, ChunkedStream) => Unit, chunkedFileHandler: (ChannelPromise, ChunkedFile) => Unit, reactiveStreamHandler: (ChannelPromise, Publisher[HttpContent]) => Unit, - wsHandler: ChannelPromise => Unit, + wsHandler: ReactiveWebSocketProcessorNettyResponseContent => Unit, noBodyHandler: () => Unit ): Unit = { r.body match { @@ -255,7 +313,7 @@ class NettyServerHandler[F[_]]( case r: ChunkedStreamNettyResponseContent => chunkedStreamHandler(r.channelPromise, r.chunkedStream) case r: ChunkedFileNettyResponseContent => chunkedFileHandler(r.channelPromise, r.chunkedFile) case r: ReactivePublisherNettyResponseContent => reactiveStreamHandler(r.channelPromise, r.publisher) - case r: ReactiveWebSocketProcessorNettyResponseContent => wsHandler(r.channelPromise) + case r: ReactiveWebSocketProcessorNettyResponseContent => wsHandler(r) } } case None => noBodyHandler() diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/ws/ReactiveWebSocketHandler.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/ws/ReactiveWebSocketHandler.scala deleted file mode 100644 index ef1cf87a34..0000000000 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/ws/ReactiveWebSocketHandler.scala +++ /dev/null @@ -1,213 +0,0 @@ -package sttp.tapir.server.netty.internal.ws - -import io.netty.buffer.Unpooled -import io.netty.channel.group.ChannelGroup -import io.netty.channel.{ChannelFuture, ChannelHandlerContext, ChannelInboundHandlerAdapter} -import io.netty.handler.codec.http._ -import io.netty.handler.codec.http.websocketx.WebSocketServerHandshakerFactory -import io.netty.handler.timeout.ReadTimeoutHandler -import io.netty.util.ReferenceCountUtil -import org.playframework.netty.http.DefaultWebSocketHttpResponse -import org.slf4j.LoggerFactory -import sttp.monad.MonadError -import sttp.monad.syntax._ -import sttp.tapir.server.model.ServerResponse -import sttp.tapir.server.netty.NettyResponseContent.ReactiveWebSocketProcessorNettyResponseContent -import sttp.tapir.server.netty.internal._ -import sttp.tapir.server.netty.{NettyResponse, NettyResponseContent, NettyServerRequest, Route} - -import scala.concurrent.{ExecutionContext, Future} -import scala.util.control.NonFatal -import scala.util.{Failure, Success} -import sttp.tapir.server.netty.NettyResponseContent.ByteBufNettyResponseContent -import io.netty.channel.ChannelPromise -import io.netty.buffer.ByteBuf -import java.util.concurrent.atomic.AtomicBoolean - -/** Handles a WS handshake and initiates the communication by calling Tapir interpreter to get a Pipe, then sends that Pipe to the rest of - * the processing pipeline and removes itself from the pipeline. - */ -class ReactiveWebSocketHandler[F[_]]( - route: Route[F], - channelGroup: ChannelGroup, - unsafeRunAsync: (() => F[ServerResponse[NettyResponse]]) => (Future[ServerResponse[NettyResponse]], () => Future[Unit]), - isSsl: Boolean, - isShuttingDown: AtomicBoolean, - serverHeader: Option[String] -)(implicit m: MonadError[F]) - extends ChannelInboundHandlerAdapter { - - // By using the Netty event loop assigned to this channel we get two benefits: - // 1. 
We can avoid the necessary hopping around of threads since Netty pipelines will - // only pass events up and down from within the event loop to which it is assigned. - // That means calls to ctx.read(), and ctx.write(..), would have to be trampolined otherwise. - // 2. We get serialization of execution: the EventLoop is a serial execution queue so - // we can rest easy knowing that no two events will be executed in parallel. - private[this] var eventLoopContext: ExecutionContext = _ - - private val logger = LoggerFactory.getLogger(getClass.getName) - private val WebSocketAutoPingHandlerName = "wsAutoPingHandler" - - def isWsHandshake(req: HttpRequest): Boolean = - "Websocket".equalsIgnoreCase(req.headers().get(HttpHeaderNames.UPGRADE)) && - "Upgrade".equalsIgnoreCase(req.headers().get(HttpHeaderNames.CONNECTION)) - - override def handlerAdded(ctx: ChannelHandlerContext): Unit = - if (ctx.channel.isActive) { - initHandler(ctx) - } - override def channelActive(ctx: ChannelHandlerContext): Unit = { - channelGroup.add(ctx.channel) - initHandler(ctx) - } - - private[this] def initHandler(ctx: ChannelHandlerContext): Unit = { - if (eventLoopContext == null) - eventLoopContext = ExecutionContext.fromExecutor(ctx.channel.eventLoop) - } - - override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = { - logger.error("Error while processing the request", cause) - } - - override def channelRead(ctx: ChannelHandlerContext, msg: Any): Unit = { - def replyWithError500(reason: Throwable): Unit = { - logger.error("Error while processing the request", reason) - val res = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.INTERNAL_SERVER_ERROR) - res.headers().set(HttpHeaderNames.CONTENT_LENGTH, 0) - res.headers.set(HttpHeaderNames.CONNECTION, HttpHeaderValues.CLOSE) - val _ = ctx.writeAndFlush(res).close() - } - - def replyWith503(): Unit = { - val res = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.SERVICE_UNAVAILABLE) - res.headers().set(HttpHeaderNames.CONTENT_LENGTH, 0) - res.headers.set(HttpHeaderNames.CONNECTION, HttpHeaderValues.CLOSE) - val _ = ctx.writeAndFlush(res).close() - } - - def replyNotFound(): Unit = { - val res = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.NOT_FOUND) - res.headers.set(HttpHeaderNames.CONNECTION, HttpHeaderValues.CLOSE) - val _ = ctx.writeAndFlush(res).close() - } - - def replyWithRouteResponse( - req: HttpRequest, - serverResponse: ServerResponse[NettyResponse], - channelPromise: ChannelPromise, - byteBuf: ByteBuf - ): Unit = { - val res = new DefaultFullHttpResponse(req.protocolVersion(), HttpResponseStatus.valueOf(serverResponse.code.code), byteBuf) - res.setHeadersFrom(serverResponse, serverHeader) - res.headers().set(HttpHeaderNames.CONTENT_LENGTH, byteBuf.readableBytes()) - val _ = ctx.writeAndFlush(res, channelPromise) - } - - msg match { - case req: FullHttpRequest if isWsHandshake(req) => - if (isShuttingDown.get()) { - logger.info("Rejecting WS handshake request because the server is shutting down.") - replyWith503() - } else { - ReferenceCountUtil.release(msg) - val (runningFuture, _) = unsafeRunAsync { () => - route(NettyServerRequest(req.retain())) - .map { - case Some(response) => response - case None => ServerResponse.notFound - } - } - - runningFuture.onComplete { - case Success(serverResponse) if serverResponse == ServerResponse.notFound => - replyNotFound() - case Success(serverResponse) => - try { - serverResponse.body match { - case Some(function) => { - val 
content = function(ctx) - content match { - case r: ReactiveWebSocketProcessorNettyResponseContent => - initWsPipeline(ctx, r, req) - case ByteBufNettyResponseContent(channelPromise, byteBuf) => - // Handshake didn't return a Pipe, but a regular response. Returning it back. - replyWithRouteResponse(req, serverResponse, channelPromise, byteBuf) - case otherContent => - // An unsupported type of regular response, returning 500 - replyWithError500( - new IllegalArgumentException(s"Unsupported response type for a WS endpoint: ${otherContent.getClass.getName}") - ) - } - } - case None => - // Handshake didn't return a Pipe, but a regular response with empty body. Returning it back. - replyWithRouteResponse(req, serverResponse, ctx.newPromise(), Unpooled.EMPTY_BUFFER) - } - } catch { - case NonFatal(ex) => - replyWithError500(ex) - } finally { - val _ = req.release() - } - case Failure(ex) => - try { - replyWithError500(ex) - } finally { - val _ = req.release() - } - }(eventLoopContext) - } - case other => - // not a WS handshake - val _ = ctx.fireChannelRead(other) - } - } - - private def initWsPipeline( - ctx: ChannelHandlerContext, - r: ReactiveWebSocketProcessorNettyResponseContent, - handshakeReq: FullHttpRequest - ) = { - ctx.pipeline().remove(this) - ctx.pipeline().remove(classOf[ReadTimeoutHandler]) - ctx - .pipeline() - .addAfter( - ServerCodecHandlerName, - WebSocketControlFrameHandlerName, - new NettyControlFrameHandler( - ignorePong = r.ignorePong, - autoPongOnPing = r.autoPongOnPing, - decodeCloseRequests = r.decodeCloseRequests - ) - ) - r.autoPing.foreach { case (interval, pingMsg) => - ctx - .pipeline() - .addAfter( - WebSocketControlFrameHandlerName, - WebSocketAutoPingHandlerName, - new WebSocketAutoPingHandler(interval, pingMsg) - ) - } - // Manually completing the promise, for some reason it won't be completed in writeAndFlush. We need its completion for NettyBodyListener to call back properly - r.channelPromise.setSuccess() - val _ = ctx.writeAndFlush( - // Push a special message down the pipeline, it will be handled by HttpStreamsServerHandler - // and from now on that handler will take control of the flow (our NettyServerHandler will not receive messages) - new DefaultWebSocketHttpResponse( - handshakeReq.protocolVersion(), - HttpResponseStatus.valueOf(200), - r.processor, // the Processor (Pipe) created by Tapir interpreter will be used by HttpStreamsServerHandler - new WebSocketServerHandshakerFactory(wsUrl(handshakeReq), null, false) - ) - ) - } - - // Only ancient WS protocol versions will use this in the response header. 
- private def wsUrl(req: FullHttpRequest): String = { - val scheme = if (isSsl) "wss" else "ws" - s"$scheme://${req.headers().get(HttpHeaderNames.HOST)}${req.uri()}" - } -} diff --git a/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/NettyZioServer.scala b/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/NettyZioServer.scala index 093b3bd966..390899e21a 100644 --- a/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/NettyZioServer.scala +++ b/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/NettyZioServer.scala @@ -86,14 +86,13 @@ case class NettyZioServer[R](routes: Vector[RIO[R, Route[RIO[R, *]]]], options: NettyBootstrap[RIO[R, *]]( config, - List( - new NettyServerHandler[RIO[R, *]]( - route, - unsafeRunAsync(runtime), - channelGroup, - isShuttingDown, - config.serverHeader - ) + new NettyServerHandler[RIO[R, *]]( + route, + unsafeRunAsync(runtime), + channelGroup, + isShuttingDown, + config.serverHeader, + config.isSsl ), eventLoopGroup, socketOverride From c19674251d47814153c32ba0b97aaa900ab659fc Mon Sep 17 00:00:00 2001 From: kciesielski Date: Fri, 29 Mar 2024 15:27:31 +0100 Subject: [PATCH 31/31] Remove unneeded handler name --- .../src/main/scala/sttp/tapir/server/netty/NettyConfig.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyConfig.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyConfig.scala index e371851211..d098a6874c 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyConfig.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyConfig.scala @@ -108,7 +108,6 @@ case class NettyConfig( } object NettyConfig { - val StreamsHandlerName = "streamsHandler" def default: NettyConfig = NettyConfig( host = "localhost", @@ -132,7 +131,7 @@ object NettyConfig { def defaultInitPipeline(cfg: NettyConfig)(pipeline: ChannelPipeline, handler: ChannelHandler): Unit = { cfg.sslContext.foreach(s => pipeline.addLast(s.newHandler(pipeline.channel().alloc()))) pipeline.addLast(ServerCodecHandlerName, new HttpServerCodec()) - pipeline.addLast(StreamsHandlerName, new HttpStreamsServerHandler()) + pipeline.addLast(new HttpStreamsServerHandler()) pipeline.addLast(handler) if (cfg.addLoggingHandler) pipeline.addLast(new LoggingHandler()) ()
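Taken together, the last two patches simplify the pipeline wiring: initPipeline now receives the single NettyServerHandler instead of a handler list, the HttpStreamsServerHandler no longer needs a registered name, and config.isSsl (derived from sslContext) lets the handler choose ws:// or wss:// for the handshake URL. A rough sketch, based only on the signatures visible in this diff, of how a custom pipeline hook could look with the new shape (the object name and the extra LoggingHandler are illustrative, not part of the patch):

import io.netty.channel.{ChannelHandler, ChannelPipeline}
import io.netty.handler.logging.LoggingHandler
import sttp.tapir.server.netty.NettyConfig

object CustomPipelineExample {
  // Delegate to the default setup (codec + streams handler + the server handler),
  // then append an extra handler of our own.
  val customInit: NettyConfig => (ChannelPipeline, ChannelHandler) => Unit =
    cfg =>
      (pipeline, handler) => {
        NettyConfig.defaultInitPipeline(cfg)(pipeline, handler)
        pipeline.addLast(new LoggingHandler()) // e.g. wire-level logging for debugging
        ()
      }

  val config: NettyConfig = NettyConfig.default.initPipeline(customInit)
}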