diff --git a/examples/src/main/scala/sttp/tapir/examples/WebSocketsNettyCatsServer.scala b/examples/src/main/scala/sttp/tapir/examples/WebSocketsNettyCatsServer.scala
new file mode 100644
index 0000000000..50729f0c2e
--- /dev/null
+++ b/examples/src/main/scala/sttp/tapir/examples/WebSocketsNettyCatsServer.scala
@@ -0,0 +1,73 @@
+package sttp.tapir.examples
+
+import cats.effect.{IO, IOApp}
+import sttp.client3._
+import sttp.tapir.server.netty.cats.NettyCatsServer
+import sttp.tapir.*
+import sttp.capabilities.fs2.Fs2Streams
+import sttp.ws.WebSocket
+import sttp.client3.pekkohttp.PekkoHttpBackend
+import scala.concurrent.Future
+
+object WebSocketsNettyCatsServer extends IOApp.Simple {
+  // One endpoint on GET /hello with query parameter `name`
+  val helloWorldEndpoint: PublicEndpoint[String, Unit, String, Any] =
+    endpoint.get.in("hello").in(query[String]("name")).out(stringBody)
+
+  val wsEndpoint =
+    endpoint.get.in("ws").out(webSocketBody[String, CodecFormat.TextPlain, String, CodecFormat.TextPlain](Fs2Streams[IO]))
+
+  val wsServerEndpoint = wsEndpoint.serverLogicSuccess(_ =>
+    IO.pure(in => in.evalMap(str => IO.println(s"responding with ${str.toUpperCase}") >> IO.pure(str.toUpperCase())))
+  )
+  // Returns the passed name with `Hello, ` prepended
+  val helloWorldServerEndpoint = helloWorldEndpoint
+    .serverLogic(name => IO.pure[Either[Unit, String]](Right(s"Hello, $name!")))
+
+  private val declaredPort = 9090
+  private val declaredHost = "localhost"
+
+  // Starts the server, then exercises it with a plain HTTP request and a WebSocket session
+  override def run = NettyCatsServer
+    .io()
+    .use { server =>
+      for {
+        binding <- server
+          .port(declaredPort)
+          .host(declaredHost)
+          .addEndpoints(List(wsServerEndpoint, helloWorldServerEndpoint))
+          .start()
+        result <- IO
+          .fromFuture(IO.delay {
+            val port = binding.port
+            val host = binding.hostName
+            println(s"Server started on port $port")
+            import scala.concurrent.ExecutionContext.Implicits.global
+            def useWebSocket(ws: WebSocket[Future]): Future[Unit] = {
+              def send(i: Int) = ws.sendText(s"Hello $i!")
+              def receive() = ws.receiveText().map(t => println(s"Client RECEIVED: $t"))
+              for {
+                _ <- send(1)
+                _ <- receive()
+                _ <- send(2)
+                _ <- send(3)
+                _ <- receive()
+              } yield ()
+            }
+            val backend = PekkoHttpBackend()
+
+            val url = uri"ws://$host:$port/ws"
+            val allGood = uri"http://$host:$port/hello?name=Netty"
+            basicRequest.response(asStringAlways).get(allGood).send(backend).map(r => println(r.body))
+              .flatMap { _ =>
+                basicRequest
+                  .response(asWebSocket(useWebSocket))
+                  .get(url)
+                  .send(backend)
+              }
+              .andThen { case _ => backend.close() }
+          })
+          .guarantee(binding.stop())
+      } yield result
+    }
+}
diff --git a/perf-tests/src/main/scala/sttp/tapir/perf/netty/cats/NettyCats.scala b/perf-tests/src/main/scala/sttp/tapir/perf/netty/cats/NettyCats.scala
index 85d3ac7c2b..10639eb7a9 100644
--- a/perf-tests/src/main/scala/sttp/tapir/perf/netty/cats/NettyCats.scala
+++ b/perf-tests/src/main/scala/sttp/tapir/perf/netty/cats/NettyCats.scala
@@ -3,16 +3,36 @@ package sttp.tapir.perf.netty.cats
 import cats.effect.IO
 import cats.effect.kernel.Resource
 import cats.effect.std.Dispatcher
+import fs2.Stream
+import sttp.tapir.{CodecFormat, webSocketBody}
 import sttp.tapir.perf.Common._
 import sttp.tapir.perf.apis._
 import sttp.tapir.server.ServerEndpoint
 import sttp.tapir.server.netty.cats.NettyCatsServer
 import sttp.tapir.server.netty.cats.NettyCatsServerOptions
+import sttp.ws.WebSocketFrame
+import sttp.capabilities.fs2.Fs2Streams
 
-object Tapir extends Endpoints
+import scala.concurrent.duration._
 
-object NettyCats {
+object Tapir extends Endpoints {
+  val wsResponseStream = Stream.fixedRate[IO](WebSocketSingleResponseLag, dampen = false)
+  val wsEndpoint = wsBaseEndpoint
+    .out(
+      webSocketBody[Long, CodecFormat.TextPlain, Long, CodecFormat.TextPlain](Fs2Streams[IO])
+        .concatenateFragmentedFrames(false)
+        .autoPongOnPing(false)
+        .ignorePong(true)
+        .autoPing(None)
+    )
+}
+object NettyCats {
+  val wsServerEndpoint = Tapir.wsEndpoint.serverLogicSuccess(_ =>
+    IO.pure { (in: Stream[IO, Long]) =>
+      Tapir.wsResponseStream.evalMap(_ => IO.realTime.map(_.toMillis)).concurrently(in.as(()))
+    }
+  )
   def runServer(endpoints: List[ServerEndpoint[Any, IO]], withServerLog: Boolean = false): IO[ServerRunner.KillSwitch] = {
     val declaredPort = Port
     val declaredHost = "0.0.0.0"
@@ -25,7 +45,7 @@ object NettyCats {
           server
             .port(declaredPort)
             .host(declaredHost)
-            .addEndpoints(endpoints)
+            .addEndpoints(wsServerEndpoint :: endpoints)
             .start()
         )(binding => binding.stop())
       } yield ()).allocated.map(_._2)
diff --git a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServer.scala b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServer.scala
index 1b7a449326..abe2b64be4 100644
--- a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServer.scala
+++ b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServer.scala
@@ -23,12 +23,14 @@ import java.util.UUID
 import java.util.concurrent.atomic.AtomicBoolean
 import scala.concurrent.Future
 import scala.concurrent.duration._
+import sttp.capabilities.WebSockets
+import sttp.tapir.server.netty.internal.ReactiveWebSocketHandler
 
 case class NettyCatsServer[F[_]: Async](routes: Vector[Route[F]], options: NettyCatsServerOptions[F], config: NettyConfig) {
   def addEndpoint(se: ServerEndpoint[Fs2Streams[F], F]): NettyCatsServer[F] = addEndpoints(List(se))
   def addEndpoint(se: ServerEndpoint[Fs2Streams[F], F], overrideOptions: NettyCatsServerOptions[F]): NettyCatsServer[F] =
     addEndpoints(List(se), overrideOptions)
-  def addEndpoints(ses: List[ServerEndpoint[Fs2Streams[F], F]]): NettyCatsServer[F] = addRoute(
+  def addEndpoints(ses: List[ServerEndpoint[Fs2Streams[F] with WebSockets, F]]): NettyCatsServer[F] = addRoute(
     NettyCatsServerInterpreter(options).toRoute(ses)
   )
   def addEndpoints(ses: List[ServerEndpoint[Fs2Streams[F], F]], overrideOptions: NettyCatsServerOptions[F]): NettyCatsServer[F] = addRoute(
@@ -75,6 +77,7 @@
       NettyBootstrap(
         config,
         new NettyServerHandler(route, unsafeRunAsync, channelGroup, isShuttingDown, config.serverHeader),
+        new ReactiveWebSocketHandler(route, channelGroup, unsafeRunAsync, config.sslContext.isDefined),
         eventLoopGroup,
         socketOverride
       )
diff --git a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServerInterpreter.scala b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServerInterpreter.scala
index 9e1c4722c0..8f65e39a73 100644
--- a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServerInterpreter.scala
+++ b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServerInterpreter.scala
@@ -14,12 +14,13 @@ import sttp.tapir.server.interpreter.{BodyListener, FilterServerEndpoints, Serve
 import sttp.tapir.server.netty.internal.{NettyBodyListener, RunAsync, _}
 import sttp.tapir.server.netty.cats.internal.NettyCatsRequestBody
 import sttp.tapir.server.netty.{NettyResponse, NettyServerRequest, Route}
+import sttp.capabilities.WebSockets
 
 trait NettyCatsServerInterpreter[F[_]] {
   implicit def async: Async[F]
   def nettyServerOptions: NettyCatsServerOptions[F]
-  def toRoute(ses: List[ServerEndpoint[Fs2Streams[F], F]]): Route[F] = {
+  def toRoute(ses: List[ServerEndpoint[Fs2Streams[F] with WebSockets, F]]): Route[F] = {
     implicit val monad: MonadError[F] = new CatsMonadError[F]
 
     val runAsync = new RunAsync[F] {
@@ -31,7 +32,7 @@ trait NettyCatsServerInterpreter[F[_]] {
     val createFile = nettyServerOptions.createFile
     val deleteFile = nettyServerOptions.deleteFile
 
-    val serverInterpreter = new ServerInterpreter[Fs2Streams[F], F, NettyResponse, Fs2Streams[F]](
+    val serverInterpreter = new ServerInterpreter[Fs2Streams[F] with WebSockets, F, NettyResponse, Fs2Streams[F]](
       FilterServerEndpoints(ses),
       new NettyCatsRequestBody(createFile, Fs2StreamCompatible[F](nettyServerOptions.dispatcher)),
       new NettyToStreamsResponseBody(Fs2StreamCompatible[F](nettyServerOptions.dispatcher)),
diff --git a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/Fs2StreamCompatible.scala b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/Fs2StreamCompatible.scala
index 3849fb4a19..28167f4422 100644
--- a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/Fs2StreamCompatible.scala
+++ b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/Fs2StreamCompatible.scala
@@ -1,22 +1,19 @@
 package sttp.tapir.server.netty.cats.internal
 
+import cats.effect.kernel.{Async, Sync}
+import cats.effect.std.Dispatcher
+import fs2.interop.reactivestreams.{StreamSubscriber, StreamUnicastPublisher}
+import fs2.io.file.{Files, Flags, Path}
+import fs2.{Chunk, Pipe}
 import io.netty.buffer.Unpooled
+import io.netty.handler.codec.http.websocketx.WebSocketFrame
 import io.netty.handler.codec.http.{DefaultHttpContent, HttpContent}
-import org.reactivestreams.Publisher
-import sttp.tapir.FileRange
+import org.reactivestreams.{Processor, Publisher}
+import sttp.capabilities.fs2.Fs2Streams
 import sttp.tapir.server.netty.internal._
+import sttp.tapir.{FileRange, WebSocketBodyOutput}
 
 import java.io.InputStream
-import cats.effect.std.Dispatcher
-import sttp.capabilities.fs2.Fs2Streams
-import fs2.io.file.Path
-import fs2.io.file.Files
-import cats.effect.kernel.Async
-import fs2.io.file.Flags
-import fs2.interop.reactivestreams.StreamUnicastPublisher
-import cats.effect.kernel.Sync
-import fs2.Chunk
-import fs2.interop.reactivestreams.StreamSubscriber
 
 object Fs2StreamCompatible {
@@ -68,6 +65,12 @@
       override def emptyStream: streams.BinaryStream = fs2.Stream.empty
 
+      override def asWsProcessor[REQ, RESP](
+          pipe: Pipe[F, REQ, RESP],
+          o: WebSocketBodyOutput[Pipe[F, REQ, RESP], REQ, RESP, ?, Fs2Streams[F]]
+      ): Processor[WebSocketFrame, WebSocketFrame] =
+        new WebSocketPipeProcessor[F, REQ, RESP](pipe, dispatcher, o)
+
       private def inputStreamToFs2(inputStream: () => InputStream, chunkSize: Int) =
         fs2.io.readInputStream(
           Sync[F].blocking(inputStream()),
diff --git a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/NettyCatsRequestBody.scala b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/NettyCatsRequestBody.scala
index e81f66fd00..f801d29ec2 100644
--- a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/NettyCatsRequestBody.scala
+++ b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/NettyCatsRequestBody.scala
@@ -12,6 +12,7 @@ import sttp.tapir.TapirFile
 import sttp.tapir.integ.cats.effect.CatsMonadError
 import sttp.tapir.model.ServerRequest
 import sttp.tapir.server.netty.internal.{NettyStreamingRequestBody, StreamCompatible}
+import sttp.capabilities.WebSockets
 
 private[cats] class NettyCatsRequestBody[F[_]: Async](
     val createFile: ServerRequest => F[TapirFile],
@@ -24,7 +25,8 @@
     streamCompatible.fromPublisher(publisher, maxBytes).compile.to(Chunk).map(_.toArray[Byte])
 
   override def writeToFile(serverRequest: ServerRequest, file: TapirFile, maxBytes: Option[Long]): F[Unit] =
-    (toStream(serverRequest, maxBytes).asInstanceOf[streamCompatible.streams.BinaryStream])
+    (toStream(serverRequest, maxBytes)
+      .asInstanceOf[streamCompatible.streams.BinaryStream])
       .through(
         Files[F](Files.forAsync[F]).writeAll(Path.fromNioPath(file.toPath))
       )
diff --git a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/WebSocketPipeProcessor.scala b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/WebSocketPipeProcessor.scala
new file mode 100644
index 0000000000..518db27e65
--- /dev/null
+++ b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/WebSocketPipeProcessor.scala
@@ -0,0 +1,94 @@
+package sttp.tapir.server.netty.cats.internal
+
+import cats.Applicative
+import cats.effect.kernel.Async
+import cats.effect.std.Dispatcher
+import fs2.interop.reactivestreams.{StreamSubscriber, StreamUnicastPublisher}
+import fs2.{Pipe, Stream}
+import io.netty.handler.codec.http.websocketx.{WebSocketFrame => NettyWebSocketFrame}
+import org.reactivestreams.{Processor, Publisher, Subscriber, Subscription}
+import sttp.capabilities.fs2.Fs2Streams
+import sttp.tapir.model.WebSocketFrameDecodeFailure
+import sttp.tapir.server.netty.internal.WebSocketFrameConverters._
+import sttp.tapir.{DecodeResult, WebSocketBodyOutput}
+import sttp.ws.WebSocketFrame
+
+import scala.concurrent.ExecutionContext.Implicits
+import scala.concurrent.Promise
+import scala.util.{Failure, Success}
+
+class WebSocketPipeProcessor[F[_]: Async, REQ, RESP](
+    pipe: Pipe[F, REQ, RESP],
+    dispatcher: Dispatcher[F],
+    o: WebSocketBodyOutput[Pipe[F, REQ, RESP], REQ, RESP, ?, Fs2Streams[F]]
+) extends Processor[NettyWebSocketFrame, NettyWebSocketFrame] {
+  private var subscriber: StreamSubscriber[F, NettyWebSocketFrame] = _
+  private val publisher: Promise[Publisher[NettyWebSocketFrame]] = Promise[Publisher[NettyWebSocketFrame]]()
+  private var subscription: Subscription = _
+
+  override def onSubscribe(s: Subscription): Unit = {
+    subscriber = dispatcher.unsafeRunSync(
+      StreamSubscriber[F, NettyWebSocketFrame](bufferSize = 1)
+    )
+    subscription = s
+    val in: Stream[F, NettyWebSocketFrame] = subscriber.sub.stream(Applicative[F].unit)
+    val sttpFrames = in.map { f =>
+      val sttpFrame = nettyFrameToFrame(f)
+      f.release()
+      sttpFrame
+    }
+    val stream: Stream[F, NettyWebSocketFrame] =
+      optionallyConcatenateFrames(sttpFrames, o.concatenateFragmentedFrames)
+        .map(f =>
+          o.requests.decode(f) match {
+            case x: DecodeResult.Value[REQ]    => x.v
+            case failure: DecodeResult.Failure => throw new WebSocketFrameDecodeFailure(f, failure)
+          }
+        )
+        .through(pipe)
+        .map(r => frameToNettyFrame(o.responses.encode(r)))
+        .append(fs2.Stream(frameToNettyFrame(WebSocketFrame.close)))
+
+    subscriber.sub.onSubscribe(s)
+    publisher.success(StreamUnicastPublisher(stream, dispatcher))
+  }
+
+  override def onNext(t: NettyWebSocketFrame): Unit = {
+    subscriber.sub.onNext(t)
+  }
+
+  override def onError(t: Throwable): Unit = {
+    subscriber.sub.onError(t)
+  }
+
+  override def onComplete(): Unit = {
+    subscriber.sub.onComplete()
+  }
+
+  override def subscribe(s: Subscriber[_ >: NettyWebSocketFrame]): Unit = {
+    publisher.future.onComplete {
+      case Success(p) =>
+        p.subscribe(s)
+      case Failure(ex) =>
+        subscriber.sub.onError(ex)
+        subscription.cancel()
+    }(Implicits.global)
+  }
+
+  private def optionallyConcatenateFrames(s: Stream[F, WebSocketFrame], doConcatenate: Boolean): Stream[F, WebSocketFrame] =
+    if (doConcatenate) {
+      type Accumulator = Option[Either[Array[Byte], String]]
+
+      s.mapAccumulate(None: Accumulator) {
+        case (None, f: WebSocketFrame.Ping)                                   => (None, Some(f))
+        case (None, f: WebSocketFrame.Pong)                                   => (None, Some(f))
+        case (None, f: WebSocketFrame.Close)                                  => (None, Some(f))
+        case (None, f: WebSocketFrame.Data[_]) if f.finalFragment             => (None, Some(f))
+        case (Some(Left(acc)), f: WebSocketFrame.Binary) if f.finalFragment   => (None, Some(f.copy(payload = acc ++ f.payload)))
+        case (Some(Left(acc)), f: WebSocketFrame.Binary) if !f.finalFragment  => (Some(Left(acc ++ f.payload)), None)
+        case (Some(Right(acc)), f: WebSocketFrame.Text) if f.finalFragment    => (None, Some(f.copy(payload = acc + f.payload)))
+        case (Some(Right(acc)), f: WebSocketFrame.Text) if !f.finalFragment   => (Some(Right(acc + f.payload)), None)
+        case (acc, f) => throw new IllegalStateException(s"Cannot accumulate web socket frames. Accumulator: $acc, frame: $f.")
+      }.collect { case (_, Some(f)) => f }
+    } else s
+}
diff --git a/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsServerTest.scala b/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsServerTest.scala
index e7a1280719..e3252a25a1 100644
--- a/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsServerTest.scala
+++ b/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsServerTest.scala
@@ -40,7 +40,11 @@
         new ServerStreamingTests(createServerTest).tests(Fs2Streams[IO])(drainFs2) ++
         new ServerCancellationTests(createServerTest)(m, IO.asyncForIO).tests() ++
         new NettyFs2StreamingCancellationTest(createServerTest).tests() ++
-        new ServerGracefulShutdownTests(createServerTest, ioSleeper).tests()
+        new ServerGracefulShutdownTests(createServerTest, ioSleeper).tests() ++
+        new ServerWebSocketTests(createServerTest, Fs2Streams[IO]) {
+          override def functionToPipe[A, B](f: A => B): streams.Pipe[A, B] = in => in.map(f)
+          override def emptyPipe[A, B]: fs2.Pipe[IO, A, B] = _ => fs2.Stream.empty
+        }.tests()
 
       IO.pure((tests, eventLoopGroup))
     } { case (_, eventLoopGroup) =>
diff --git a/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsTestServerInterpreter.scala b/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsTestServerInterpreter.scala
index 61be7e6f4d..9334504a2f 100644
--- a/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsTestServerInterpreter.scala
+++ b/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsTestServerInterpreter.scala
@@ -4,6 +4,7 @@ import cats.data.NonEmptyList
 import cats.effect.std.Dispatcher
 import cats.effect.{IO, Resource}
 import io.netty.channel.nio.NioEventLoopGroup
+import sttp.capabilities.WebSockets
 import sttp.tapir.server.ServerEndpoint
 import sttp.tapir.server.netty.{NettyConfig, Route}
 import sttp.tapir.server.tests.TestServerInterpreter
@@ -12,8 +13,8 @@
 import sttp.capabilities.fs2.Fs2Streams
 import scala.concurrent.duration.FiniteDuration
 
 class NettyCatsTestServerInterpreter(eventLoopGroup: NioEventLoopGroup, dispatcher: Dispatcher[IO])
-    extends TestServerInterpreter[IO, Fs2Streams[IO], NettyCatsServerOptions[IO], Route[IO]] {
-  override def route(es: List[ServerEndpoint[Fs2Streams[IO], IO]], interceptors: Interceptors): Route[IO] = {
+    extends TestServerInterpreter[IO, Fs2Streams[IO] with WebSockets, NettyCatsServerOptions[IO], Route[IO]] {
+  override def route(es: List[ServerEndpoint[Fs2Streams[IO] with WebSockets, IO]], interceptors: Interceptors): Route[IO] = {
     val serverOptions: NettyCatsServerOptions[IO] = interceptors(
       NettyCatsServerOptions.customiseInterceptors[IO](dispatcher)
     ).options
diff --git a/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/NettyIdServer.scala b/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/NettyIdServer.scala
index 46918fc8fd..2de073f763 100644
--- a/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/NettyIdServer.scala
+++ b/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/NettyIdServer.scala
@@ -25,6 +25,7 @@ import scala.concurrent.Future
 import scala.concurrent.Promise
 import scala.concurrent.duration.FiniteDuration
 import scala.util.control.NonFatal
+import sttp.tapir.server.netty.internal.ReactiveWebSocketHandler
 
 case class NettyIdServer(routes: Vector[IdRoute], options: NettyIdServerOptions, config: NettyConfig) {
   private val executor = Executors.newVirtualThreadPerTaskExecutor()
@@ -100,6 +101,7 @@
             isShuttingDown,
             config.serverHeader
           ),
+          new ReactiveWebSocketHandler(route, channelGroup, unsafeRunF, config.sslContext.isDefined),
           eventLoopGroup,
           socketOverride
         )
diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyConfig.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyConfig.scala
index 07e43954c5..1d6d66144d 100644
--- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyConfig.scala
+++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyConfig.scala
@@ -61,7 +61,7 @@ case class NettyConfig(
     sslContext: Option[SslContext],
     eventLoopConfig: EventLoopConfig,
     socketConfig: NettySocketConfig,
-    initPipeline: NettyConfig => (ChannelPipeline, ChannelHandler) => Unit,
+    initPipeline: NettyConfig => (ChannelPipeline, ChannelHandler, ChannelHandler) => Unit,
     gracefulShutdownTimeout: Option[FiniteDuration],
     serverHeader: Option[String]
 ) {
@@ -96,7 +96,7 @@
   def eventLoopConfig(elc: EventLoopConfig): NettyConfig = copy(eventLoopConfig = elc)
   def eventLoopGroup(elg: EventLoopGroup): NettyConfig = copy(eventLoopConfig = EventLoopConfig.useExisting(elg))
 
-  def initPipeline(f: NettyConfig => (ChannelPipeline, ChannelHandler) => Unit): NettyConfig = copy(initPipeline = f)
+  def initPipeline(f: NettyConfig => (ChannelPipeline, ChannelHandler, ChannelHandler) => Unit): NettyConfig = copy(initPipeline = f)
 
   def withGracefulShutdownTimeout(t: FiniteDuration) = copy(gracefulShutdownTimeout = Some(t))
   def noGracefulShutdown = copy(gracefulShutdownTimeout = None)
 
@@ -120,14 +120,15 @@ object NettyConfig {
     sslContext = None,
     eventLoopConfig = EventLoopConfig.auto,
     socketConfig = NettySocketConfig.default,
-    initPipeline = cfg => defaultInitPipeline(cfg)(_, _),
+    initPipeline = cfg => defaultInitPipeline(cfg)(_, _, _),
     serverHeader = Some(s"tapir/${buildinfo.BuildInfo.version}")
   )
 
-  def defaultInitPipeline(cfg: NettyConfig)(pipeline: ChannelPipeline, handler: ChannelHandler): Unit = {
+  def defaultInitPipeline(cfg: NettyConfig)(pipeline: ChannelPipeline, handler: ChannelHandler, wsHandler: ChannelHandler): Unit = {
     cfg.sslContext.foreach(s => pipeline.addLast(s.newHandler(pipeline.channel().alloc())))
-    pipeline.addLast(new HttpServerCodec())
-    pipeline.addLast(new HttpStreamsServerHandler())
+    pipeline.addLast("serverCodecHandler", new HttpServerCodec())
+    pipeline.addLast("streamsHandler", new HttpStreamsServerHandler())
+    pipeline.addLast("wsHandler", wsHandler)
     pipeline.addLast(handler)
     if (cfg.addLoggingHandler) pipeline.addLast(new LoggingHandler())
     ()
diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyFutureServer.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyFutureServer.scala
index f510e8e85b..2c972c3226 100644
--- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyFutureServer.scala
+++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyFutureServer.scala
@@ -16,6 +16,7 @@ import java.util.UUID
 import java.util.concurrent.atomic.AtomicBoolean
 import scala.concurrent.duration.FiniteDuration
 import scala.concurrent.{ExecutionContext, Future, blocking}
+import sttp.tapir.server.netty.internal.ReactiveWebSocketHandler
 
 case class NettyFutureServer(routes: Vector[FutureRoute], options: NettyFutureServerOptions, config: NettyConfig)(implicit
     ec: ExecutionContext
@@ -72,6 +73,7 @@
       NettyBootstrap(
         config,
         new NettyServerHandler(route, unsafeRunAsync, channelGroup, isShuttingDown, config.serverHeader),
+        new ReactiveWebSocketHandler(route, channelGroup, unsafeRunAsync, config.sslContext.isDefined),
         eventLoopGroup,
         socketOverride
       )
diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyResponseContent.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyResponseContent.scala
index b17335f60b..d3aea1dda6 100644
--- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyResponseContent.scala
+++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyResponseContent.scala
@@ -5,6 +5,10 @@
 import io.netty.channel.ChannelPromise
 import io.netty.handler.codec.http.HttpContent
 import io.netty.handler.stream.{ChunkedFile, ChunkedStream}
 import org.reactivestreams.Publisher
+import org.reactivestreams.Processor
+import io.netty.handler.codec.http.websocketx.WebSocketFrame
+import sttp.ws.{WebSocketFrame => SttpWebSocketFrame}
+import scala.concurrent.duration.FiniteDuration
 
 sealed trait NettyResponseContent {
   def channelPromise: ChannelPromise
@@ -17,4 +21,12 @@
   final case class ChunkedFileNettyResponseContent(channelPromise: ChannelPromise, chunkedFile: ChunkedFile) extends NettyResponseContent
 
   final case class ReactivePublisherNettyResponseContent(channelPromise: ChannelPromise, publisher: Publisher[HttpContent]) extends NettyResponseContent
+
+  final case class ReactiveWebSocketProcessorNettyResponseContent(
+      channelPromise: ChannelPromise,
+      processor: Processor[WebSocketFrame, WebSocketFrame],
+      ignorePong: Boolean,
+      autoPongOnPing: Boolean,
+      decodeCloseRequests: Boolean,
+      autoPing: Option[(FiniteDuration, SttpWebSocketFrame.Ping)]
+  ) extends NettyResponseContent
 }
diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyBootstrap.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyBootstrap.scala
index 71d57519a8..68835c8f07 100644
--- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyBootstrap.scala
+++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyBootstrap.scala
@@ -12,6 +12,7 @@ object NettyBootstrap {
   def apply[F[_]](
       nettyConfig: NettyConfig,
       handler: => NettyServerHandler[F],
+      wsHandler: => ReactiveWebSocketHandler[F],
      eventLoopGroup: EventLoopGroup,
      overrideSocketAddress: Option[SocketAddress]
  ): ChannelFuture = {
@@ -27,8 +28,12 @@
 
         nettyConfig.requestTimeout match {
           case Some(requestTimeout) =>
-            nettyConfigBuilder(ch.pipeline().addLast(new ReadTimeoutHandler(requestTimeout.toSeconds.toInt)), handler)
-          case None => nettyConfigBuilder(ch.pipeline(), handler)
+            nettyConfigBuilder(
+              ch.pipeline().addLast("readTimeoutHandler", new ReadTimeoutHandler(requestTimeout.toSeconds.toInt)),
+              handler,
+              wsHandler
+            )
+          case None => nettyConfigBuilder(ch.pipeline(), handler, wsHandler)
         }
 
         connectionCounterOpt.map(counter => {
diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala
index c3333a7ff4..3da92d8f56 100644
--- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala
+++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala
@@ -4,6 +4,7 @@ import io.netty.buffer.{ByteBuf, Unpooled}
 import io.netty.channel._
 import io.netty.channel.group.ChannelGroup
 import io.netty.handler.codec.http._
+import io.netty.handler.codec.http.websocketx.WebSocketServerHandshakerFactory
 import io.netty.handler.stream.{ChunkedFile, ChunkedStream}
 import org.playframework.netty.http.{DefaultStreamedHttpResponse, StreamedHttpRequest}
 import org.reactivestreams.Publisher
@@ -25,6 +26,7 @@
 import scala.collection.mutable.{Queue => MutableQueue}
 import scala.concurrent.{ExecutionContext, Future}
 import scala.util.control.NonFatal
 import scala.util.{Failure, Success}
+import sttp.tapir.server.netty.NettyResponseContent.ReactiveWebSocketProcessorNettyResponseContent
 
 /** @param unsafeRunAsync
   *   Function which dispatches given effect to run asynchronously, returning its result as a Future, and function of type `() =>
@@ -211,6 +214,12 @@
           })
 
           ctx.writeAndFlush(res, channelPromise).closeIfNeeded(req)
+        },
+        wsHandler = (channelPromise) => {
+          logger.error("Unexpected WebSocket processor response received in NettyServerHandler; it should be handled only by ReactiveWebSocketHandler")
+          val res = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.INTERNAL_SERVER_ERROR)
+          res.headers.set(HttpHeaderNames.CONNECTION, HttpHeaderValues.CLOSE)
+          val _ = ctx.writeAndFlush(res)
         },
         noBodyHandler = () => {
           val res = new DefaultFullHttpResponse(
@@ -234,6 +243,7 @@
       chunkedStreamHandler: (ChannelPromise, ChunkedStream) => Unit,
       chunkedFileHandler: (ChannelPromise, ChunkedFile) => Unit,
       reactiveStreamHandler: (ChannelPromise, Publisher[HttpContent]) => Unit,
+      wsHandler: ChannelPromise => Unit,
       noBodyHandler: () => Unit
   ): Unit = {
     r.body match {
@@ -245,6 +255,7 @@
         case r: ChunkedStreamNettyResponseContent => chunkedStreamHandler(r.channelPromise, r.chunkedStream)
         case r: ChunkedFileNettyResponseContent => chunkedFileHandler(r.channelPromise, r.chunkedFile)
         case r: ReactivePublisherNettyResponseContent => reactiveStreamHandler(r.channelPromise, r.publisher)
+        case r: ReactiveWebSocketProcessorNettyResponseContent => wsHandler(r.channelPromise)
       }
     }
     case None => noBodyHandler()
diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyToStreamsResponseBody.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyToStreamsResponseBody.scala
index 0f335a7b14..ae94a180b3 100644
--- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyToStreamsResponseBody.scala
+++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyToStreamsResponseBody.scala
@@ -12,12 +12,14 @@ import sttp.tapir.{CodecFormat, RawBodyType, WebSocketBodyOutput}
 
 import java.nio.ByteBuffer
 import java.nio.charset.Charset
+import sttp.tapir.server.netty.NettyResponseContent.ReactiveWebSocketProcessorNettyResponseContent
 
-/** Common logic for producing response body in all Netty backends that support streaming. These backends use streaming libraries
-  * like fs2 or zio-streams to obtain reactive Publishers representing responses like InputStreamBody, InputStreamRangeBody or FileBody.
-  * Other kinds of raw responses like directly available String, ByteArray or ByteBuffer can be returned without wrapping into a Publisher.
+/** Common logic for producing response body in all Netty backends that support streaming. These backends use streaming libraries like fs2
+  * or zio-streams to obtain reactive Publishers representing responses like InputStreamBody, InputStreamRangeBody or FileBody. Other kinds
+  * of raw responses like directly available String, ByteArray or ByteBuffer can be returned without wrapping into a Publisher.
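+  *
+  * For example (an illustrative sketch only, assuming fs2 streams on cats-effect IO; these endpoints are not part of this change), the
+  * first output below is interpreted into a reactive Publisher[HttpContent], while the second is written directly:
+  * {{{
+  * endpoint.get.in("stream").out(streamTextBody(Fs2Streams[IO])(CodecFormat.TextPlain())) // becomes a Publisher
+  * endpoint.get.in("raw").out(stringBody)                                                 // written as-is
+  * }}}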
  */
-private[netty] class NettyToStreamsResponseBody[S <: Streams[S]](streamCompatible: StreamCompatible[S]) extends ToResponseBody[NettyResponse, S] {
+private[netty] class NettyToStreamsResponseBody[S <: Streams[S]](streamCompatible: StreamCompatible[S])
+    extends ToResponseBody[NettyResponse, S] {
 
   override val streams: S = streamCompatible.streams
 
@@ -37,7 +39,10 @@
 
     case RawBodyType.InputStreamBody =>
       (ctx: ChannelHandlerContext) =>
-        new ReactivePublisherNettyResponseContent(ctx.newPromise(), streamCompatible.publisherFromInputStream(() => v, DefaultChunkSize, length = None))
+        new ReactivePublisherNettyResponseContent(
+          ctx.newPromise(),
+          streamCompatible.publisherFromInputStream(() => v, DefaultChunkSize, length = None)
+        )
 
     case RawBodyType.InputStreamRangeBody =>
       (ctx: ChannelHandlerContext) =>
@@ -47,7 +52,8 @@
       )
 
     case RawBodyType.FileBody =>
-      (ctx: ChannelHandlerContext) => new ReactivePublisherNettyResponseContent(ctx.newPromise(), streamCompatible.publisherFromFile(v, DefaultChunkSize))
+      (ctx: ChannelHandlerContext) =>
+        new ReactivePublisherNettyResponseContent(ctx.newPromise(), streamCompatible.publisherFromFile(v, DefaultChunkSize))
 
     case _: RawBodyType.MultipartBody => throw new UnsupportedOperationException
   }
@@ -68,5 +74,17 @@
   override def fromWebSocketPipe[REQ, RESP](
       pipe: streams.Pipe[REQ, RESP],
       o: WebSocketBodyOutput[streams.Pipe[REQ, RESP], REQ, RESP, _, S]
-  ): NettyResponse = throw new UnsupportedOperationException
+  ): NettyResponse = (ctx: ChannelHandlerContext) => {
+    new ReactiveWebSocketProcessorNettyResponseContent(
+      ctx.newPromise(),
+      streamCompatible.asWsProcessor(
+        pipe.asInstanceOf[streamCompatible.streams.Pipe[REQ, RESP]],
+        o.asInstanceOf[WebSocketBodyOutput[streamCompatible.streams.Pipe[REQ, RESP], REQ, RESP, _, S]]
+      ),
+      ignorePong = o.ignorePong,
+      autoPongOnPing = o.autoPongOnPing,
+      decodeCloseRequests = o.decodeCloseRequests,
+      autoPing = o.autoPing
+    )
+  }
 }
diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/ReactiveWebSocketHandler.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/ReactiveWebSocketHandler.scala
new file mode 100644
index 0000000000..7fc695ab78
--- /dev/null
+++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/ReactiveWebSocketHandler.scala
@@ -0,0 +1,161 @@
+package sttp.tapir.server.netty.internal
+
+import io.netty.channel.group.ChannelGroup
+import io.netty.channel.{ChannelHandlerContext, ChannelInboundHandlerAdapter}
+import io.netty.handler.codec.http.websocketx.WebSocketServerHandshakerFactory
+import io.netty.handler.codec.http._
+import io.netty.util.ReferenceCountUtil
+import org.playframework.netty.http.DefaultWebSocketHttpResponse
+import org.slf4j.LoggerFactory
+import sttp.monad.MonadError
+import sttp.monad.syntax._
+import sttp.tapir.server.model.ServerResponse
+import sttp.tapir.server.netty.NettyResponseContent.ReactiveWebSocketProcessorNettyResponseContent
+import sttp.tapir.server.netty.{NettyResponse, NettyServerRequest, Route}
+
+import scala.concurrent.{ExecutionContext, Future}
+import scala.util.control.NonFatal
+import scala.util.{Failure, Success}
+
+/** Handles a WS handshake and initiates the communication by calling the Tapir interpreter to get a Pipe, then sends that Pipe to the
+  * rest of the processing pipeline and removes itself from the pipeline.
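+  *
+  * A handshake is recognized by its upgrade headers (see [[isWsHandshake]]); an illustrative handshake request looks like:
+  * {{{
+  * GET /ws HTTP/1.1
+  * Connection: Upgrade
+  * Upgrade: websocket
+  * }}}
+  * Requests without these headers are not treated as handshakes: the handler removes itself and fires the message further down the
+  * pipeline unchanged.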
+  */
+class ReactiveWebSocketHandler[F[_]](
+    route: Route[F],
+    channelGroup: ChannelGroup,
+    unsafeRunAsync: (() => F[ServerResponse[NettyResponse]]) => (Future[ServerResponse[NettyResponse]], () => Future[Unit]),
+    isSsl: Boolean
+)(implicit m: MonadError[F])
+    extends ChannelInboundHandlerAdapter {
+
+  // By using the Netty event loop assigned to this channel we get two benefits:
+  //  1. We avoid otherwise-necessary thread hopping, since Netty pipelines only pass events up and down
+  //     from within the event loop to which the channel is assigned. Calls to ctx.read() and ctx.write(..)
+  //     would have to be trampolined otherwise.
+  //  2. We get serialization of execution: the EventLoop is a serial execution queue, so we can rest easy
+  //     knowing that no two events will be executed in parallel.
+  private[this] var eventLoopContext: ExecutionContext = _
+
+  private val logger = LoggerFactory.getLogger(getClass.getName)
+
+  def isWsHandshake(req: HttpRequest): Boolean =
+    "Websocket".equalsIgnoreCase(req.headers().get(HttpHeaderNames.UPGRADE)) &&
+      "Upgrade".equalsIgnoreCase(req.headers().get(HttpHeaderNames.CONNECTION))
+
+  override def handlerAdded(ctx: ChannelHandlerContext): Unit =
+    if (ctx.channel.isActive) {
+      initHandler(ctx)
+    }
+
+  override def channelActive(ctx: ChannelHandlerContext): Unit = {
+    channelGroup.add(ctx.channel)
+    initHandler(ctx)
+  }
+
+  private[this] def initHandler(ctx: ChannelHandlerContext): Unit = {
+    if (eventLoopContext == null)
+      eventLoopContext = ExecutionContext.fromExecutor(ctx.channel.eventLoop)
+  }
+
+  override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = {
+    logger.error("Error while processing the request", cause)
+  }
+
+  override def channelRead(ctx: ChannelHandlerContext, msg: Any): Unit = {
+    def writeError500(req: HttpRequest, reason: Throwable): Unit = {
+      logger.error("Error while processing the request", reason)
+      val res = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.INTERNAL_SERVER_ERROR)
+      res.headers().set(HttpHeaderNames.CONTENT_LENGTH, 0)
+      val _ = ctx.writeAndFlush(res)
+    }
+    msg match {
+      case req: FullHttpRequest if isWsHandshake(req) =>
+        ctx.pipeline().remove(this)
+        // the read timeout handler is only installed when a request timeout is configured, so guard against removing a missing handler
+        if (ctx.pipeline().get("readTimeoutHandler") != null) {
+          val _ = ctx.pipeline().remove("readTimeoutHandler")
+        }
+        ReferenceCountUtil.release(msg)
+        val (runningFuture, _) = unsafeRunAsync { () =>
+          route(NettyServerRequest(req.retain()))
+            .map {
+              case Some(response) => response
+              case None           => ServerResponse.notFound
+            }
+        }
+
+        val _ = runningFuture.transform {
+          case Success(serverResponse) =>
+            try {
+              serverResponse.body match {
+                case Some(function) => {
+                  val content = function(ctx)
+                  content match {
+                    case r: ReactiveWebSocketProcessorNettyResponseContent => {
+                      ctx
+                        .pipeline()
+                        .addAfter(
+                          "serverCodecHandler",
+                          "wsControlFrameHandler",
+                          new NettyControlFrameHandler(
+                            ignorePong = r.ignorePong,
+                            autoPongOnPing = r.autoPongOnPing,
+                            decodeCloseRequests = r.decodeCloseRequests
+                          )
+                        )
+                      r.autoPing.foreach { case (interval, pingMsg) =>
+                        ctx.pipeline().addFirst("wsAutoPingHandler", new WebSocketAutoPingHandler(interval, pingMsg))
+                      }
+                      r.channelPromise.setSuccess()
+                      val _ = ctx.writeAndFlush(
+                        // Push a special message down the pipeline; it will be handled by HttpStreamsServerHandler,
+                        // which from now on takes control of the flow (our NettyServerHandler will not receive messages)
+                        new DefaultWebSocketHttpResponse(
+                          req.protocolVersion(),
+                          HttpResponseStatus.valueOf(200),
+                          r.processor, // the Processor (Pipe) created by the Tapir interpreter will be used by HttpStreamsServerHandler
+                          new WebSocketServerHandshakerFactory(wsUrl(req), null, false)
+                        )
+                      )
+                    }
+                    case otherContent =>
+                      logger.error(s"Unexpected response content: $otherContent")
+                      val res = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.INTERNAL_SERVER_ERROR)
+                      res.headers.set(HttpHeaderNames.CONNECTION, HttpHeaderValues.CLOSE)
+                      otherContent.channelPromise.setFailure(new IllegalStateException("Unexpected response content"))
+                      val _ = ctx.writeAndFlush(res)
+                  }
+                }
+                case None =>
+                  logger.error("Missing response body, expected a ReactiveWebSocketProcessorNettyResponseContent")
+                  val res = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.INTERNAL_SERVER_ERROR)
+                  res.headers.set(HttpHeaderNames.CONNECTION, HttpHeaderValues.CLOSE)
+                  val _ = ctx.writeAndFlush(res)
+              }
+              Success(())
+            } catch {
+              case NonFatal(ex) =>
+                writeError500(req, ex)
+                Failure(ex)
+            } finally {
+              val _ = req.release()
+            }
+          case Failure(NonFatal(ex)) =>
+            try {
+              writeError500(req, ex)
+              Failure(ex)
+            } finally {
+              val _ = req.release()
+            }
+          case Failure(fatalException) => Failure(fatalException)
+        }(eventLoopContext)
+
+      case other =>
+        // not a WS handshake; from now on process messages as normal HTTP requests in this channel
+        ctx.pipeline.remove(this)
+        val _ = ctx.fireChannelRead(other)
+    }
+  }
+
+  // Only ancient WS protocol versions will use this in the response header.
+  private def wsUrl(req: FullHttpRequest): String = {
+    val scheme = if (isSsl) "wss" else "ws"
+    s"$scheme://${req.headers().get(HttpHeaderNames.HOST)}${req.uri()}"
+  }
+}
diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/StreamCompatible.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/StreamCompatible.scala
index 6d5da177bd..b973aef7eb 100644
--- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/StreamCompatible.scala
+++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/StreamCompatible.scala
@@ -1,9 +1,10 @@
 package sttp.tapir.server.netty.internal
 
 import io.netty.handler.codec.http.HttpContent
-import org.reactivestreams.Publisher
+import io.netty.handler.codec.http.websocketx.WebSocketFrame
+import org.reactivestreams.{Processor, Publisher}
 import sttp.capabilities.Streams
-import sttp.tapir.FileRange
+import sttp.tapir.{FileRange, WebSocketBodyOutput}
 
 import java.io.InputStream
 
@@ -27,4 +28,9 @@
   def publisherFromInputStream(is: () => InputStream, chunkSize: Int, length: Option[Long]): Publisher[HttpContent] =
     asPublisher(fromInputStream(is, chunkSize, length))
+
+  def asWsProcessor[REQ, RESP](
+      pipe: streams.Pipe[REQ, RESP],
+      o: WebSocketBodyOutput[streams.Pipe[REQ, RESP], REQ, RESP, _, S]
+  ): Processor[WebSocketFrame, WebSocketFrame]
 }
diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/WebSocketAutoPingHandler.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/WebSocketAutoPingHandler.scala
new file mode 100644
index 0000000000..62addee363
--- /dev/null
+++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/WebSocketAutoPingHandler.scala
@@ -0,0 +1,45 @@
+package sttp.tapir.server.netty.internal
+
+import io.netty.buffer.Unpooled
+import io.netty.channel.{ChannelHandlerContext, ChannelInboundHandlerAdapter}
+import io.netty.handler.codec.http.websocketx.PingWebSocketFrame
+import org.slf4j.LoggerFactory
+
+import java.util.concurrent.{ScheduledFuture, TimeUnit}
+import scala.concurrent.duration.FiniteDuration
+
+/** If auto ping is enabled for an endpoint, this handler will be plugged into the pipeline. Its responsibility is to start and stop the
+  * ping scheduler. TODO: should we include logic for closing the channel and reporting an error if some kind of ping timeout is exceeded?
+  * @param pingInterval
+  *   time interval to be used between sending pings to the client.
+  * @param frame
+  *   desired content of the Ping frame, as specified in the Tapir endpoint output.
+  */
+class WebSocketAutoPingHandler(pingInterval: FiniteDuration, frame: sttp.ws.WebSocketFrame.Ping) extends ChannelInboundHandlerAdapter {
+  val nettyFrame = new PingWebSocketFrame(Unpooled.copiedBuffer(frame.payload))
+  private var pingTask: ScheduledFuture[_] = _
+
+  private val logger = LoggerFactory.getLogger(getClass.getName)
+
+  override def handlerAdded(ctx: ChannelHandlerContext): Unit = {
+    if (ctx.channel.isActive) {
+      logger.debug(s"Starting WebSocket Ping scheduler for channel ${ctx.channel}, interval = $pingInterval")
+      val sendPing: Runnable = new Runnable {
+        override def run(): Unit = {
+          logger.trace(s"Sending PING WebSocket frame for channel ${ctx.channel}")
+          val _ = ctx.writeAndFlush(nettyFrame.retain())
+        }
+      }
+      pingTask =
+        ctx.channel().eventLoop().scheduleAtFixedRate(sendPing, pingInterval.toMillis, pingInterval.toMillis, TimeUnit.MILLISECONDS)
+    }
+  }
+
+  override def channelInactive(ctx: ChannelHandlerContext): Unit = {
+    super.channelInactive(ctx)
+    logger.debug(s"Stopping WebSocket Ping scheduler for channel ${ctx.channel}")
+    if (pingTask != null) {
+      val _ = pingTask.cancel(false)
+    }
+  }
+}
diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/WebSocketControlFrameHandler.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/WebSocketControlFrameHandler.scala
new file mode 100644
index 0000000000..098e76bcfe
--- /dev/null
+++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/WebSocketControlFrameHandler.scala
@@ -0,0 +1,33 @@
+package sttp.tapir.server.netty.internal
+
+import io.netty.channel.ChannelInboundHandlerAdapter
+import io.netty.channel.ChannelHandlerContext
+import io.netty.handler.codec.http.websocketx.PingWebSocketFrame
+import io.netty.handler.codec.http.websocketx.PongWebSocketFrame
+import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame
+
+/**
+  * Handles Ping, Pong, and Close frames for WebSockets.
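+  *
+  * The constructor flags mirror the webSocketBody options declared on the endpoint, e.g. (an illustrative endpoint only, assuming fs2
+  * on cats-effect IO):
+  * {{{
+  * endpoint.get.in("ws").out(
+  *   webSocketBody[String, CodecFormat.TextPlain, String, CodecFormat.TextPlain](Fs2Streams[IO])
+  *     .autoPongOnPing(true)       // Pings are answered by this handler and don't reach the Pipe
+  *     .ignorePong(true)           // incoming Pongs are dropped by this handler
+  *     .decodeCloseRequests(false) // Close requests are answered immediately with a Close frame
+  * )
+  * }}}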
+  */
+class NettyControlFrameHandler(ignorePong: Boolean, autoPongOnPing: Boolean, decodeCloseRequests: Boolean)
+    extends ChannelInboundHandlerAdapter {
+
+  override def channelRead(ctx: ChannelHandlerContext, msg: Any): Unit = {
+    msg match {
+      case ping: PingWebSocketFrame if autoPongOnPing =>
+        // Responding with a Pong that reuses the Ping's payload; release the Ping wrapper afterwards to avoid leaking its buffer
+        val _ = ctx.writeAndFlush(new PongWebSocketFrame(ping.content().retain()))
+        val _ = ping.release()
+      case pong: PongWebSocketFrame =>
+        if (ignorePong) {
+          // Dropping the Pong entirely (otherwise it would fall through to the default case and be passed on)
+          val _ = pong.release()
+        } else {
+          val _ = ctx.fireChannelRead(pong)
+        }
+      case close: CloseWebSocketFrame =>
+        if (decodeCloseRequests) {
+          // Passing the Close frame for further processing
+          val _ = ctx.fireChannelRead(close)
+        } else {
+          // Responding with Close immediately
+          val _ = ctx.writeAndFlush(close)
+        }
+      case other =>
+        val _ = ctx.fireChannelRead(other)
+    }
+  }
+}
diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/WebSocketFrameConverters.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/WebSocketFrameConverters.scala
new file mode 100644
index 0000000000..2d5bc06615
--- /dev/null
+++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/WebSocketFrameConverters.scala
@@ -0,0 +1,32 @@
+package sttp.tapir.server.netty.internal
+
+import io.netty.handler.codec.http.websocketx._
+import io.netty.handler.codec.http.websocketx.{WebSocketFrame => NettyWebSocketFrame}
+import sttp.ws.WebSocketFrame
+import io.netty.buffer.{ByteBuf, Unpooled}
+
+object WebSocketFrameConverters {
+
+  // Copies out the readable bytes; content().nioBuffer().array() would fail for direct buffers and ignore reader indices
+  private def getBytes(buf: ByteBuf): Array[Byte] = {
+    val bytes = new Array[Byte](buf.readableBytes())
+    buf.getBytes(buf.readerIndex(), bytes)
+    bytes
+  }
+
+  def nettyFrameToFrame(nettyFrame: NettyWebSocketFrame): WebSocketFrame = {
+    nettyFrame match {
+      case text: TextWebSocketFrame   => WebSocketFrame.Text(text.text, text.isFinalFragment, Some(text.rsv))
+      case close: CloseWebSocketFrame => WebSocketFrame.Close(close.statusCode, close.reasonText)
+      case ping: PingWebSocketFrame   => WebSocketFrame.Ping(getBytes(ping.content()))
+      case pong: PongWebSocketFrame   => WebSocketFrame.Pong(getBytes(pong.content()))
+      case _ => WebSocketFrame.Binary(getBytes(nettyFrame.content()), nettyFrame.isFinalFragment, Some(nettyFrame.rsv))
+    }
+  }
+
+  def frameToNettyFrame(w: WebSocketFrame): NettyWebSocketFrame = w match {
+    case WebSocketFrame.Text(payload, finalFragment, rsvOpt) =>
+      new TextWebSocketFrame(finalFragment, rsvOpt.getOrElse(0), payload)
+    case WebSocketFrame.Close(statusCode, reasonText) =>
+      new CloseWebSocketFrame(statusCode, reasonText)
+    case WebSocketFrame.Ping(payload) =>
+      new PingWebSocketFrame(Unpooled.wrappedBuffer(payload))
+    case WebSocketFrame.Pong(payload) =>
+      new PongWebSocketFrame(Unpooled.wrappedBuffer(payload))
+    case WebSocketFrame.Binary(payload, finalFragment, rsvOpt) =>
+      new BinaryWebSocketFrame(finalFragment, rsvOpt.getOrElse(0), Unpooled.wrappedBuffer(payload))
+  }
+}
diff --git a/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/NettyZioServer.scala b/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/NettyZioServer.scala
index 8eb00c55bc..da8c9082d3 100644
--- a/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/NettyZioServer.scala
+++ b/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/NettyZioServer.scala
@@ -7,7 +7,7 @@
 import io.netty.util.concurrent.DefaultEventExecutor
 import sttp.capabilities.zio.ZioStreams
 import sttp.tapir.server.ServerEndpoint
 import sttp.tapir.server.model.ServerResponse
-import sttp.tapir.server.netty.internal.{NettyBootstrap, NettyServerHandler}
+import sttp.tapir.server.netty.internal.{NettyBootstrap, NettyServerHandler, ReactiveWebSocketHandler}
 import sttp.tapir.server.netty.zio.internal.ZioUtil.{nettyChannelFutureToScala, nettyFutureToScala}
 import sttp.tapir.server.netty.{NettyConfig, NettyResponse, Route}
 import sttp.tapir.ztapir.{RIOMonadError, ZServerEndpoint}
@@ -93,6 +93,7 @@
             isShuttingDown,
             config.serverHeader
           ),
+          new ReactiveWebSocketHandler[RIO[R, *]](route, channelGroup, unsafeRunAsync(runtime), config.sslContext.isDefined),
           eventLoopGroup,
           socketOverride
         )
diff --git a/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/internal/ZioStreamCompatible.scala b/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/internal/ZioStreamCompatible.scala
index 60a5a2b067..14ad954784 100644
--- a/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/internal/ZioStreamCompatible.scala
+++ b/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/internal/ZioStreamCompatible.scala
@@ -4,11 +4,12 @@
 import _root_.zio._
 import _root_.zio.interop.reactivestreams._
 import _root_.zio.stream.{Stream, ZStream}
 import io.netty.buffer.Unpooled
+import io.netty.handler.codec.http.websocketx.WebSocketFrame
 import io.netty.handler.codec.http.{DefaultHttpContent, HttpContent}
-import org.reactivestreams.Publisher
+import org.reactivestreams.{Processor, Publisher}
 import sttp.capabilities.zio.ZioStreams
-import sttp.tapir.FileRange
 import sttp.tapir.server.netty.internal._
+import sttp.tapir.{FileRange, WebSocketBodyOutput}
 
 import java.io.InputStream
 
@@ -64,6 +65,12 @@
         override def emptyStream: streams.BinaryStream = ZStream.empty
+
+        override def asWsProcessor[REQ, RESP](
+            pipe: Stream[Throwable, REQ] => Stream[Throwable, RESP],
+            o: WebSocketBodyOutput[Stream[Throwable, REQ] => Stream[Throwable, RESP], REQ, RESP, ?, ZioStreams]
+        ): Processor[WebSocketFrame, WebSocketFrame] =
+          throw new UnsupportedOperationException("TODO")
       }
     }
   }
diff --git a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerWebSocketTests.scala b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerWebSocketTests.scala
index 378fd86ddf..05cfaad8bb 100644
--- a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerWebSocketTests.scala
+++ b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerWebSocketTests.scala
@@ -43,11 +43,13 @@
           _ <- ws.sendText("test2")
           m1 <- ws.receiveText()
           m2 <- ws.receiveText()
-        } yield List(m1, m2)
+          _ <- ws.close()
+          m3 <- ws.eitherClose(ws.receiveText())
+        } yield List(m1, m2, m3)
       })
       .get(baseUri.scheme("ws"))
       .send(backend)
-      .map(_.body shouldBe Right(List("echo: test1", "echo: test2")))
+      .map(_.body shouldBe Right(List("echo: test1", "echo: test2", Left(WebSocketFrame.Close(1000, "normal closure")))))
   },
   {
     val reqCounter = newRequestCounter[F]