From 1cf911da0e07ff801df243444a2ad0df1dc47a7a Mon Sep 17 00:00:00 2001 From: Krzysztof Ciesielski Date: Fri, 19 Apr 2024 10:25:43 +0200 Subject: [PATCH] Web Sockets for netty-loom (#3675) --- build.sbt | 19 ++- doc/server/netty.md | 28 ++- .../examples/HelloWorldNettySyncServer.scala | 29 ++++ .../websocket/WebSocketNettySyncServer.scala | 56 ++++++ perf-tests/README.md | 8 +- .../sttp/tapir/perf/apis/Endpoints.scala | 2 - .../tapir/perf/netty/loom/NettySync.scala | 32 ---- project/Versions.scala | 2 + .../server/netty/loom/NettyOxStreams.scala | 11 ++ .../server/netty/loom/NettySyncServer.scala | 157 +++++++++-------- .../loom/NettySyncServerInterpreter.scala | 48 ++++-- .../netty/loom/NettySyncServerOptions.scala | 15 +- .../{ => internal}/NettySyncRequestBody.scala | 14 +- .../internal/NettySyncToResponseBody.scala | 47 +++++ .../netty/loom/internal/ox/OxDispatcher.scala | 31 ++++ .../reactivestreams/ChannelSubscription.scala | 37 ++++ .../reactivestreams/OxProcessor.scala | 79 +++++++++ .../ws/OxSourceWebSocketProcessor.scala | 81 +++++++++ .../sttp/tapir/server/netty/loom/loom.scala | 3 +- .../netty/loom/NettySyncServerTest.scala | 160 +++++++++++++++--- .../loom/NettySyncTestServerInterpreter.scala | 52 +++++- .../loom/perf/NettySyncServerRunner.scala | 98 +++++++++++ 22 files changed, 820 insertions(+), 189 deletions(-) create mode 100644 examples/src/main/scala/sttp/tapir/examples/HelloWorldNettySyncServer.scala create mode 100644 examples/src/main/scala/sttp/tapir/examples/websocket/WebSocketNettySyncServer.scala delete mode 100644 perf-tests/src/main/scala/sttp/tapir/perf/netty/loom/NettySync.scala create mode 100644 server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/NettyOxStreams.scala rename server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/{ => internal}/NettySyncRequestBody.scala (73%) create mode 100644 server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/internal/NettySyncToResponseBody.scala create mode 100644 server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/internal/ox/OxDispatcher.scala create mode 100644 server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/internal/reactivestreams/ChannelSubscription.scala create mode 100644 server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/internal/reactivestreams/OxProcessor.scala create mode 100644 server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/internal/ws/OxSourceWebSocketProcessor.scala create mode 100644 server/netty-server/loom/src/test/scala/sttp/tapir/server/netty/loom/perf/NettySyncServerRunner.scala diff --git a/build.sbt b/build.sbt index 0b8c2e6d75..931129780b 100644 --- a/build.sbt +++ b/build.sbt @@ -249,6 +249,9 @@ lazy val rawAllAggregates = core.projectRefs ++ derevo.projectRefs ++ awsCdk.projectRefs +def buildWithLoom(project: String): Boolean = + project.contains("Loom") || project.contains("nima") || project.contains("perfTests") || project.contains("examples3") + lazy val allAggregates: Seq[ProjectReference] = { val filteredByNative = if (sys.env.isDefinedAt("STTP_NATIVE")) { println("[info] STTP_NATIVE defined, including native in the aggregate projects") @@ -259,13 +262,13 @@ lazy val allAggregates: Seq[ProjectReference] = { } if (sys.env.isDefinedAt("ONLY_LOOM")) { println("[info] ONLY_LOOM defined, including only loom-based projects") - filteredByNative.filter(p => (p.toString.contains("Loom") || p.toString.contains("nima") || p.toString.contains("perfTests"))) + filteredByNative.filter(p => buildWithLoom(p.toString)) } else if (sys.env.isDefinedAt("ALSO_LOOM")) { println("[info] ALSO_LOOM defined, including also loom-based projects") filteredByNative } else { println("[info] ONLY_LOOM *not* defined, *not* including loom-based-projects") - filteredByNative.filterNot(p => (p.toString.contains("Loom") || p.toString.contains("nima") || p.toString.contains("perfTests"))) + filteredByNative.filterNot(p => buildWithLoom(p.toString)) } } @@ -545,7 +548,6 @@ lazy val perfTests: ProjectMatrix = (projectMatrix in file("perf-tests")) http4sServer, nettyServer, nettyServerCats, - nettyServerLoom, playServer, vertxServer, vertxServerCats, @@ -1454,9 +1456,15 @@ lazy val nettyServerLoom: ProjectMatrix = .settings( name := "tapir-netty-server-loom", // needed because of https://github.com/coursier/coursier/issues/2016 - useCoursier := false + useCoursier := false, + Test / run / fork := true, + libraryDependencies ++= Seq( + "com.softwaremill.ox" %% "core" % Versions.ox, + "org.reactivestreams" % "reactive-streams-tck" % Versions.reactiveStreams % Test, + "com.disneystreaming" %% "weaver-cats" % "0.8.4" % Test + ) ) - .jvmPlatform(scalaVersions = scala2_13And3Versions) + .jvmPlatform(scalaVersions = List(scala3)) .dependsOn(nettyServer, serverTests % Test) lazy val nettyServerCats: ProjectMatrix = nettyServerProject("cats", catsEffect) @@ -2141,6 +2149,7 @@ lazy val examples: ProjectMatrix = (projectMatrix in file("examples")) sttpClient, swaggerUiBundle, http4sServerZio, + nettyServerLoom, nettyServerZio, zioHttpServer, zioJson, diff --git a/doc/server/netty.md b/doc/server/netty.md index c6e7e44adf..7cb1f16d84 100644 --- a/doc/server/netty.md +++ b/doc/server/netty.md @@ -44,21 +44,11 @@ val binding: Future[NettyFutureServerBinding] = NettyFutureServer().addEndpoint(helloWorld).start() ``` -The `tapir-netty-server-loom` server uses `Id[T]` as its wrapper effect for compatibility, while `Id[A]` means in fact just `A`, representing direct style. +The `tapir-netty-server-loom` server uses `Id[T]` as its wrapper effect for compatibility, while `Id[A]` means in fact just `A`, representing direct style. It is +available only for Scala 3. +See [examples/HelloWorldNettySyncServer.scala](https://github.com/softwaremill/tapir/blob/master/examples/src/main/scala/sttp/tapir/examples/HelloWorldNettySyncServer.scala) for a full example. +To learn more about handling concurrency with Ox, see the [documentation](https://ox.softwaremill.com/). -```scala -import sttp.tapir._ -import sttp.tapir.server.netty.loom.{Id, NettySyncServer, NettySyncServerBinding} - -val helloWorld = endpoint - .get - .in("hello").in(query[String]("name")) - .out(stringBody) - .serverLogicSuccess[Id](name => s"Hello, $name!") - -val binding: NettySyncServerBinding = - NettySyncServer().addEndpoint(helloWorld).start() -``` ## Configuration @@ -85,7 +75,10 @@ NettyFutureServer(NettyConfig.default.socketBacklog(256)) ## Web sockets -The netty-cats interpreter supports web sockets, with pipes of type `fs2.Pipe[F, REQ, RESP]`. See [web sockets](../endpoint/websockets.md) + +### tapir-netty-server-cats + +The Cats Effects interpreter supports web sockets, with pipes of type `fs2.Pipe[F, REQ, RESP]`. See [web sockets](../endpoint/websockets.md) for more details. To create a web socket endpoint, use Tapir's `out(webSocketBody)` output type: @@ -148,6 +141,11 @@ object WebSocketsNettyCatsServer extends ResourceApp.Forever { } ``` +### tapir-netty-server-loom + +In the Loom-based backend, Tapir uses [Ox](https://ox.softwaremill.com) to manage concurrency, and your transformation pipeline should be represented as `Ox ?=> Source[A] => Source[B]`. Any forks started within this function will be run under a safely isolated internal scope. +See [examples/websocket/WebSocketNettySyncServer.scala](https://github.com/softwaremill/tapir/blob/master/examples/src/main/scala/sttp/tapir/examples/websocket/WebSocketNettySyncServer.scala) for a full example. + ## Graceful shutdown A Netty server can be gracefully closed using the function `NettyFutureServerBinding.stop()` (and analogous functions available in Cats and ZIO bindings). This function ensures that the server will wait at most 10 seconds for in-flight requests to complete, while rejecting all new requests with 503 during this period. Afterwards, it closes all server resources. diff --git a/examples/src/main/scala/sttp/tapir/examples/HelloWorldNettySyncServer.scala b/examples/src/main/scala/sttp/tapir/examples/HelloWorldNettySyncServer.scala new file mode 100644 index 0000000000..93cb0a1af8 --- /dev/null +++ b/examples/src/main/scala/sttp/tapir/examples/HelloWorldNettySyncServer.scala @@ -0,0 +1,29 @@ +package sttp.tapir.examples + +import ox.* +import sttp.tapir.* +import sttp.tapir.server.netty.loom.{Id, NettySyncServer} + +object HelloWorldNettySyncServer: + val helloWorld = endpoint.get + .in("hello") + .in(query[String]("name")) + .out(stringBody) + .serverLogicSuccess[Id](name => s"Hello, $name!") + + NettySyncServer().addEndpoint(helloWorld).startAndWait() + +// Alternatively, if you need manual control of the structured concurrency scope, server lifecycle, +// or just metadata from `NettySyncServerBinding` (like port number), use `start()`: +object HelloWorldNettySyncServer2: + val helloWorld = endpoint.get + .in("hello") + .in(query[String]("name")) + .out(stringBody) + .serverLogicSuccess[Id](name => s"Hello, $name!") + + supervised { + val serverBinding = useInScope(NettySyncServer().addEndpoint(helloWorld).start())(_.stop()) + println(s"Tapir is running on port ${serverBinding.port}") + never + } diff --git a/examples/src/main/scala/sttp/tapir/examples/websocket/WebSocketNettySyncServer.scala b/examples/src/main/scala/sttp/tapir/examples/websocket/WebSocketNettySyncServer.scala new file mode 100644 index 0000000000..3814248a64 --- /dev/null +++ b/examples/src/main/scala/sttp/tapir/examples/websocket/WebSocketNettySyncServer.scala @@ -0,0 +1,56 @@ +package sttp.tapir.examples.websocket + +import ox.* +import ox.channels.* +import sttp.capabilities.WebSockets +import sttp.tapir.* +import sttp.tapir.server.netty.loom.Id +import sttp.tapir.server.netty.loom.OxStreams +import sttp.tapir.server.netty.loom.OxStreams.Pipe // alias for Ox ?=> Source[A] => Source[B] +import sttp.tapir.server.netty.loom.NettySyncServer +import sttp.ws.WebSocketFrame + +import scala.concurrent.duration.* + +object WebSocketNettySyncServer: + // Web socket endpoint + val wsEndpoint = + endpoint.get + .in("ws") + .out( + webSocketBody[String, CodecFormat.TextPlain, String, CodecFormat.TextPlain](OxStreams) + .concatenateFragmentedFrames(false) // All these options are supported by tapir-netty + .ignorePong(true) + .autoPongOnPing(true) + .decodeCloseRequests(false) + .decodeCloseResponses(false) + .autoPing(Some((10.seconds, WebSocketFrame.Ping("ping-content".getBytes)))) + ) + + // Your processor transforming a stream of requests into a stream of responses + val wsPipe: Pipe[String, String] = requestStream => requestStream.map(_.toUpperCase) + // Alternatively, requests and responses can be treated separately, for example to emit frames to the client from another source: + val wsPipe2: Pipe[String, String] = { in => + fork { + in.drain() // read and ignore requests + } + // emit periodic responses + Source.tick(1.second).map(_ => System.currentTimeMillis()).map(_.toString) + } + + // The WebSocket endpoint, builds the pipeline in serverLogicSuccess + val wsServerEndpoint = wsEndpoint.serverLogicSuccess[Id](_ => wsPipe) + + // A regular /GET endpoint + val helloWorldEndpoint = + endpoint.get.in("hello").in(query[String]("name")).out(stringBody) + + val helloWorldServerEndpoint = helloWorldEndpoint + .serverLogicSuccess[Id](name => s"Hello, $name!") + + def main(args: Array[String]): Unit = + NettySyncServer() + .host("0.0.0.0") + .port(8080) + .addEndpoints(List(wsServerEndpoint, helloWorldServerEndpoint)) + .startAndWait() diff --git a/perf-tests/README.md b/perf-tests/README.md index d6a5dc2186..8da88766a1 100644 --- a/perf-tests/README.md +++ b/perf-tests/README.md @@ -1,6 +1,6 @@ # Performance tests -To work with performance tests, make sure you are running JDK 21+, and that the `ALSO_LOOM` environment variable is set, because the `perf-tests` project includes `tapir-netty-loom` and `tapir-nima`, which require Loom JDK feature to be available. +To work with performance tests, make sure you are running JDK 21+, and that the `ALSO_LOOM` environment variable is set, because the `perf-tests` project includes `tapir-nima`, which require Loom JDK feature to be available. Performance tests are executed by running `PerfTestSuiteRunner`, which is a standard "Main" Scala application, configured by command line parameters. It executes a sequence of tests, where each test consist of: @@ -122,6 +122,12 @@ For WebSockets we want to measure latency distribution, not throughput, so use g ``` perfTests/runMain sttp.tapir.perf.apis.ServerRunner http4s.Tapir ``` +If you're testing `NettySyncServer` (tapir-server-netty-loom), its server runner is located elsewhere: +``` +nettyServerLoom3/Test/runMain sttp.tapir.netty.loom.perf.NettySyncServerRunner +``` +This is caused by `perf-tests` using Scala 2.13 forced by Gatling, while `NettySyncServer` is written excluisively for Scala 3. + 3. Run the simulation using Gatling's task: ``` perfTests/Gatling/testOnly sttp.tapir.perf.WebSocketsSimulation diff --git a/perf-tests/src/main/scala/sttp/tapir/perf/apis/Endpoints.scala b/perf-tests/src/main/scala/sttp/tapir/perf/apis/Endpoints.scala index 7ae2f65e39..9359682160 100644 --- a/perf-tests/src/main/scala/sttp/tapir/perf/apis/Endpoints.scala +++ b/perf-tests/src/main/scala/sttp/tapir/perf/apis/Endpoints.scala @@ -3,7 +3,6 @@ package sttp.tapir.perf.apis import cats.effect.IO import sttp.tapir._ import sttp.tapir.perf.Common._ -import sttp.tapir.server.netty.loom.Id import sttp.tapir.server.ServerEndpoint import sttp.tapir.server.model.EndpointExtensions._ @@ -67,5 +66,4 @@ trait Endpoints { def genEndpointsFuture(count: Int): List[ServerEndpoint[Any, Future]] = genServerEndpoints(count)(Future.successful) def genEndpointsIO(count: Int): List[ServerEndpoint[Any, IO]] = genServerEndpoints(count)(IO.pure) - def genEndpointsId(count: Int): List[ServerEndpoint[Any, Id]] = genServerEndpoints[Id](count)(x => x: Id[String]) } diff --git a/perf-tests/src/main/scala/sttp/tapir/perf/netty/loom/NettySync.scala b/perf-tests/src/main/scala/sttp/tapir/perf/netty/loom/NettySync.scala deleted file mode 100644 index 399cb3d629..0000000000 --- a/perf-tests/src/main/scala/sttp/tapir/perf/netty/loom/NettySync.scala +++ /dev/null @@ -1,32 +0,0 @@ -package sttp.tapir.perf.netty.loom - -import cats.effect.IO -import sttp.tapir.perf.apis._ -import sttp.tapir.perf.Common._ -import sttp.tapir.server.netty.loom._ -import sttp.tapir.server.ServerEndpoint - -object Tapir extends Endpoints - -object NettySync { - - def runServer(endpoints: List[ServerEndpoint[Any, Id]], withServerLog: Boolean = false): IO[ServerRunner.KillSwitch] = { - val declaredPort = Port - val declaredHost = "0.0.0.0" - val serverOptions = buildOptions(NettySyncServerOptions.customiseInterceptors, withServerLog) - // Starting netty server - val serverBinding: NettySyncServerBinding = - NettySyncServer(serverOptions) - .port(declaredPort) - .host(declaredHost) - .addEndpoints(endpoints) - .start() - IO(IO(serverBinding.stop())) - } -} - -object TapirServer extends ServerRunner { override def start = NettySync.runServer(Tapir.genEndpointsId(1)) } -object TapirMultiServer extends ServerRunner { override def start = NettySync.runServer(Tapir.genEndpointsId(128)) } -object TapirInterceptorMultiServer extends ServerRunner { - override def start = NettySync.runServer(Tapir.genEndpointsId(128), withServerLog = true) -} diff --git a/project/Versions.scala b/project/Versions.scala index c94f024ce3..5a88444689 100644 --- a/project/Versions.scala +++ b/project/Versions.scala @@ -25,6 +25,8 @@ object Versions { val json4s = "4.0.7" val metrics4Scala = "4.2.9" val nettyReactiveStreams = "3.0.2" + val ox = "0.0.26" + val reactiveStreams = "1.0.4" val sprayJson = "1.3.6" val scalaCheck = "1.17.1" val scalaTest = "3.2.18" diff --git a/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/NettyOxStreams.scala b/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/NettyOxStreams.scala new file mode 100644 index 0000000000..9e0654ee0a --- /dev/null +++ b/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/NettyOxStreams.scala @@ -0,0 +1,11 @@ +package sttp.tapir.server.netty.loom + +import ox.Ox +import ox.channels.Source +import sttp.capabilities.Streams + +trait OxStreams extends Streams[OxStreams]: + override type BinaryStream = Nothing + override type Pipe[A, B] = Ox ?=> Source[A] => Source[B] + +object OxStreams extends OxStreams diff --git a/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/NettySyncServer.scala b/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/NettySyncServer.scala index 4418d2d4a8..b94f56f9cc 100644 --- a/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/NettySyncServer.scala +++ b/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/NettySyncServer.scala @@ -1,43 +1,49 @@ package sttp.tapir.server.netty.loom -import io.netty.channel.Channel -import io.netty.channel.EventLoopGroup +import ox.* +import internal.ox.OxDispatcher import io.netty.channel.group.{ChannelGroup, DefaultChannelGroup} import io.netty.channel.unix.DomainSocketAddress +import io.netty.channel.{Channel, EventLoopGroup} import io.netty.util.concurrent.DefaultEventExecutor +import sttp.capabilities.WebSockets import sttp.tapir.server.ServerEndpoint import sttp.tapir.server.model.ServerResponse -import sttp.tapir.server.netty.NettyConfig -import sttp.tapir.server.netty.NettyResponse -import sttp.tapir.server.netty.Route -import sttp.tapir.server.netty.internal.NettyBootstrap -import sttp.tapir.server.netty.internal.NettyServerHandler - -import java.net.InetSocketAddress -import java.net.SocketAddress +import sttp.tapir.server.netty.internal.{NettyBootstrap, NettyServerHandler} +import sttp.tapir.server.netty.{NettyConfig, NettyResponse, Route} + +import java.net.{InetSocketAddress, SocketAddress} import java.nio.file.Path -import java.nio.file.Paths -import java.util.UUID -import java.util.concurrent.Executors -import java.util.concurrent.{Future => JFuture} import java.util.concurrent.atomic.AtomicBoolean -import scala.concurrent.Future -import scala.concurrent.Promise +import java.util.concurrent.{Executors, Future => JFuture} import scala.concurrent.duration.FiniteDuration +import scala.concurrent.{Future, Promise} import scala.util.control.NonFatal -case class NettySyncServer(routes: Vector[IdRoute], options: NettySyncServerOptions, config: NettyConfig) { +/** Unlike with most typical Tapir backends, adding endpoints doesn't immediatly convert them to a Route, because creating a Route requires + * providing an Ox concurrency scope. Instead, it stores Endpoints and defers route creation until server.start() is called. This internal + * [[NettySyncServerEndpointListOverriddenOptions]] is an intermediary helper type representing added endpoints, which have custom server + * options. + */ +private[loom] case class NettySyncServerEndpointListOverridenOptions( + ses: List[ServerEndpoint[OxStreams & WebSockets, Id]], + overridenOptions: NettySyncServerOptions +) + +case class NettySyncServer( + endpoints: List[ServerEndpoint[OxStreams & WebSockets, Id]], + endpointsWithOptions: List[NettySyncServerEndpointListOverridenOptions], + options: NettySyncServerOptions, + config: NettyConfig +): private val executor = Executors.newVirtualThreadPerTaskExecutor() - def addEndpoint(se: ServerEndpoint[Any, Id]): NettySyncServer = addEndpoints(List(se)) - def addEndpoint(se: ServerEndpoint[Any, Id], overrideOptions: NettySyncServerOptions): NettySyncServer = + def addEndpoint(se: ServerEndpoint[OxStreams & WebSockets, Id]): NettySyncServer = addEndpoints(List(se)) + def addEndpoint(se: ServerEndpoint[OxStreams & WebSockets, Id], overrideOptions: NettySyncServerOptions): NettySyncServer = addEndpoints(List(se), overrideOptions) - def addEndpoints(ses: List[ServerEndpoint[Any, Id]]): NettySyncServer = addRoute(NettySyncServerInterpreter(options).toRoute(ses)) - def addEndpoints(ses: List[ServerEndpoint[Any, Id]], overrideOptions: NettySyncServerOptions): NettySyncServer = - addRoute(NettySyncServerInterpreter(overrideOptions).toRoute(ses)) - - def addRoute(r: IdRoute): NettySyncServer = copy(routes = routes :+ r) - def addRoutes(r: Iterable[IdRoute]): NettySyncServer = copy(routes = routes ++ r) + def addEndpoints(ses: List[ServerEndpoint[OxStreams & WebSockets, Id]]): NettySyncServer = copy(endpoints = endpoints ++ ses) + def addEndpoints(ses: List[ServerEndpoint[OxStreams & WebSockets, Id]], overrideOptions: NettySyncServerOptions): NettySyncServer = + copy(endpointsWithOptions = endpointsWithOptions :+ NettySyncServerEndpointListOverridenOptions(ses, overrideOptions)) def options(o: NettySyncServerOptions): NettySyncServer = copy(options = o) def config(c: NettyConfig): NettySyncServer = copy(config = c) @@ -47,28 +53,57 @@ case class NettySyncServer(routes: Vector[IdRoute], options: NettySyncServerOpti def port(p: Int): NettySyncServer = modifyConfig(_.port(p)) - def start(): NettySyncServerBinding = - startUsingSocketOverride[InetSocketAddress](None) match { + /** Use only if you need to manage server lifecycle or concurrency scope manually. Otherwise, see [[startAndWait]]. + * @example + * {{{ + * import ox.* + * + * supervised { + * val serverBinding = useInScope(server.start())(_.stop()) + * println(s"Tapir is running on port ${serverBinding.port}) + * never + * } + * }}} + * @return + * server binding, to be used to control stopping of the server or obtaining metadata like port. + */ + def start()(using Ox): NettySyncServerBinding = + startUsingSocketOverride[InetSocketAddress](None, new OxDispatcher()) match case (socket, stop) => NettySyncServerBinding(socket, stop) + + /** Starts the server and blocks current virtual thread. Ensures graceful shutdown if the running server gets interrupted. Use [[start]] + * if you need to manually control concurrency scope or server lifecycle. + */ + def startAndWait(): Unit = + supervised { + useInScope(start())(_.stop()).discard + never } - def startUsingDomainSocket(path: Option[Path] = None): NettySyncDomainSocketBinding = - startUsingDomainSocket(path.getOrElse(Paths.get(System.getProperty("java.io.tmpdir"), UUID.randomUUID().toString))) + private[netty] def start(routes: List[Route[Id]]): NettySyncServerBinding = + startUsingSocketOverride[InetSocketAddress](routes, None) match + case (socket, stop) => + NettySyncServerBinding(socket, stop) - def startUsingDomainSocket(path: Path): NettySyncDomainSocketBinding = - startUsingSocketOverride(Some(new DomainSocketAddress(path.toFile))) match { + def startUsingDomainSocket(path: Path)(using Ox): NettySyncDomainSocketBinding = + startUsingSocketOverride(Some(new DomainSocketAddress(path.toFile)), new OxDispatcher()) match case (socket, stop) => NettySyncDomainSocketBinding(socket, stop) - } - private def startUsingSocketOverride[SA <: SocketAddress](socketOverride: Option[SA]): (SA, () => Unit) = { + private def startUsingSocketOverride[SA <: SocketAddress](socketOverride: Option[SA], oxDispatcher: OxDispatcher): (SA, () => Unit) = + val routes = NettySyncServerInterpreter(options).toRoute(endpoints, oxDispatcher) :: endpointsWithOptions.map(e => + NettySyncServerInterpreter(e.overridenOptions).toRoute(e.ses, oxDispatcher) + ) + startUsingSocketOverride(routes, socketOverride) + + private def startUsingSocketOverride[SA <: SocketAddress](routes: List[Route[Id]], socketOverride: Option[SA]): (SA, () => Unit) = val eventLoopGroup = config.eventLoopConfig.initEventLoopGroup() val route = Route.combine(routes) def unsafeRunF( callToExecute: () => Id[ServerResponse[NettyResponse]] - ): (Future[ServerResponse[NettyResponse]], () => Future[Unit]) = { + ): (Future[ServerResponse[NettyResponse]], () => Future[Unit]) = val scalaPromise = Promise[ServerResponse[NettyResponse]]() val jFuture: JFuture[?] = executor.submit(new Runnable { override def run(): Unit = try { @@ -86,7 +121,7 @@ case class NettySyncServer(routes: Vector[IdRoute], options: NettySyncServerOpti Future.unit } ) - } + val eventExecutor = new DefaultEventExecutor() val channelGroup = new DefaultChannelGroup(eventExecutor) // thread safe val isShuttingDown: AtomicBoolean = new AtomicBoolean(false) @@ -104,34 +139,26 @@ case class NettySyncServer(routes: Vector[IdRoute], options: NettySyncServerOpti eventLoopGroup, socketOverride ) - try { + try channelIdFuture.sync() val channelId = channelIdFuture.channel() ( channelId.localAddress().asInstanceOf[SA], () => stop(channelId, eventLoopGroup, channelGroup, eventExecutor, isShuttingDown, config.gracefulShutdownTimeout) ) - } catch { + catch case NonFatal(startFailureCause) => - try { - stopRecovering(eventLoopGroup, channelGroup, eventExecutor, isShuttingDown, config.gracefulShutdownTimeout) - } catch { - case NonFatal(recoveryFailureCause) => startFailureCause.addSuppressed(recoveryFailureCause) - } + try stopRecovering(eventLoopGroup, channelGroup, eventExecutor, isShuttingDown, config.gracefulShutdownTimeout) + catch case NonFatal(recoveryFailureCause) => startFailureCause.addSuppressed(recoveryFailureCause) throw startFailureCause - } - } private def waitForClosedChannels( channelGroup: ChannelGroup, startNanos: Long, gracefulShutdownTimeoutNanos: Option[Long] - ): Unit = { - while (!channelGroup.isEmpty && gracefulShutdownTimeoutNanos.exists(_ >= System.nanoTime() - startNanos)) { - Thread.sleep(100) - } + ): Unit = + while !channelGroup.isEmpty && gracefulShutdownTimeoutNanos.exists(_ >= System.nanoTime() - startNanos) do Thread.sleep(100) val _ = channelGroup.close().get() - } private def stop( ch: Channel, @@ -140,7 +167,7 @@ case class NettySyncServer(routes: Vector[IdRoute], options: NettySyncServerOpti eventExecutor: DefaultEventExecutor, isShuttingDown: AtomicBoolean, gracefulShutdownTimeout: Option[FiniteDuration] - ): Unit = { + ): Unit = isShuttingDown.set(true) waitForClosedChannels( channelGroup, @@ -148,11 +175,9 @@ case class NettySyncServer(routes: Vector[IdRoute], options: NettySyncServerOpti gracefulShutdownTimeoutNanos = gracefulShutdownTimeout.map(_.toNanos) ) ch.close().get() - if (config.shutdownEventLoopGroupOnClose) { + if config.shutdownEventLoopGroupOnClose then val _ = eventLoopGroup.shutdownGracefully().get() val _ = eventExecutor.shutdownGracefully().get() - } - } private def stopRecovering( eventLoopGroup: EventLoopGroup, @@ -160,36 +185,32 @@ case class NettySyncServer(routes: Vector[IdRoute], options: NettySyncServerOpti eventExecutor: DefaultEventExecutor, isShuttingDown: AtomicBoolean, gracefulShutdownTimeout: Option[FiniteDuration] - ): Unit = { + ): Unit = isShuttingDown.set(true) waitForClosedChannels( channelGroup, startNanos = System.nanoTime(), gracefulShutdownTimeoutNanos = gracefulShutdownTimeout.map(_.toNanos) ) - if (config.shutdownEventLoopGroupOnClose) { + if config.shutdownEventLoopGroupOnClose then val _ = eventLoopGroup.shutdownGracefully().get() val _ = eventExecutor.shutdownGracefully().get() - } - } -} -object NettySyncServer { - def apply(): NettySyncServer = NettySyncServer(Vector.empty, NettySyncServerOptions.default, NettyConfig.default) +object NettySyncServer: + def apply(): NettySyncServer = NettySyncServer(List.empty, List.empty, NettySyncServerOptions.default, NettyConfig.default) def apply(serverOptions: NettySyncServerOptions): NettySyncServer = - NettySyncServer(Vector.empty, serverOptions, NettyConfig.default) + NettySyncServer(List.empty, List.empty, serverOptions, NettyConfig.default) def apply(config: NettyConfig): NettySyncServer = - NettySyncServer(Vector.empty, NettySyncServerOptions.default, config) + NettySyncServer(List.empty, List.empty, NettySyncServerOptions.default, config) def apply(serverOptions: NettySyncServerOptions, config: NettyConfig): NettySyncServer = - NettySyncServer(Vector.empty, serverOptions, config) -} -case class NettySyncServerBinding(localSocket: InetSocketAddress, stop: () => Unit) { + NettySyncServer(List.empty, List.empty, serverOptions, config) + +case class NettySyncServerBinding(localSocket: InetSocketAddress, stop: () => Unit): def hostName: String = localSocket.getHostName def port: Int = localSocket.getPort -} -case class NettySyncDomainSocketBinding(localSocket: DomainSocketAddress, stop: () => Unit) { + +case class NettySyncDomainSocketBinding(localSocket: DomainSocketAddress, stop: () => Unit): def path: String = localSocket.path() -} diff --git a/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/NettySyncServerInterpreter.scala b/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/NettySyncServerInterpreter.scala index d09dc679ea..6a257b44c6 100644 --- a/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/NettySyncServerInterpreter.scala +++ b/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/NettySyncServerInterpreter.scala @@ -1,29 +1,45 @@ package sttp.tapir.server.netty.loom +import internal.{NettySyncRequestBody, NettySyncToResponseBody} +import internal.ox.OxDispatcher +import sttp.capabilities.WebSockets +import sttp.monad.syntax._ import sttp.tapir.server.ServerEndpoint -import sttp.tapir.server.netty.internal.{NettyToResponseBody, NettyServerInterpreter, RunAsync} +import sttp.tapir.server.interceptor.reject.RejectInterceptor +import sttp.tapir.server.interceptor.RequestResult +import sttp.tapir.server.interpreter.{BodyListener, FilterServerEndpoints, ServerInterpreter} +import sttp.tapir.server.netty.internal.{NettyBodyListener, RunAsync} +import sttp.tapir.server.netty.{NettyResponse, NettyServerRequest, Route} -trait NettySyncServerInterpreter { +trait NettySyncServerInterpreter: def nettyServerOptions: NettySyncServerOptions + /** Requires implicit supervision scope (Ox), because it needs to know in which scope it can start background forks in the Web Sockets + * processor. + */ def toRoute( - ses: List[ServerEndpoint[Any, Id]] - ): IdRoute = { - NettyServerInterpreter.toRoute[Id]( - ses, - nettyServerOptions.interceptors, + ses: List[ServerEndpoint[OxStreams & WebSockets, Id]], + oxDispatcher: OxDispatcher + ): IdRoute = + implicit val bodyListener: BodyListener[Id, NettyResponse] = new NettyBodyListener(RunAsync.Id) + val serverInterpreter = new ServerInterpreter[OxStreams with WebSockets, Id, NettyResponse, OxStreams]( + FilterServerEndpoints(ses), new NettySyncRequestBody(nettyServerOptions.createFile), - new NettyToResponseBody[Id](RunAsync.Id), - nettyServerOptions.deleteFile, - RunAsync.Id + new NettySyncToResponseBody(RunAsync.Id, oxDispatcher), + RejectInterceptor.disableWhenSingleEndpoint(nettyServerOptions.interceptors, ses), + nettyServerOptions.deleteFile ) - } -} + val handler: Route[Id] = { (request: NettyServerRequest) => + serverInterpreter(request) + .map { + case RequestResult.Response(response) => Some(response) + case RequestResult.Failure(_) => None + } + } + handler -object NettySyncServerInterpreter { - def apply(serverOptions: NettySyncServerOptions = NettySyncServerOptions.default): NettySyncServerInterpreter = { +object NettySyncServerInterpreter: + def apply(serverOptions: NettySyncServerOptions = NettySyncServerOptions.default): NettySyncServerInterpreter = new NettySyncServerInterpreter { override def nettyServerOptions: NettySyncServerOptions = serverOptions } - } -} diff --git a/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/NettySyncServerOptions.scala b/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/NettySyncServerOptions.scala index 02c38b300b..d278bd8064 100644 --- a/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/NettySyncServerOptions.scala +++ b/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/NettySyncServerOptions.scala @@ -11,14 +11,14 @@ case class NettySyncServerOptions( interceptors: List[Interceptor[Id]], createFile: ServerRequest => TapirFile, deleteFile: TapirFile => Unit -) { +): def prependInterceptor(i: Interceptor[Id]): NettySyncServerOptions = copy(interceptors = i :: interceptors) def appendInterceptor(i: Interceptor[Id]): NettySyncServerOptions = copy(interceptors = interceptors :+ i) -} -object NettySyncServerOptions { +object NettySyncServerOptions: - /** Default options, using TCP sockets (the most common case). This can be later customised using [[NettySyncServerOptions#nettyOptions()]]. + /** Default options, using TCP sockets (the most common case). This can be later customised using + * [[NettySyncServerOptions#nettyOptions()]]. */ def default: NettySyncServerOptions = customiseInterceptors.options @@ -34,15 +34,14 @@ object NettySyncServerOptions { /** Customise the interceptors that are being used when exposing endpoints as a server. By default uses TCP sockets (the most common * case), but this can be later customised using [[NettySyncServerOptions#nettyOptions()]]. */ - def customiseInterceptors: CustomiseInterceptors[Id, NettySyncServerOptions] = { + def customiseInterceptors: CustomiseInterceptors[Id, NettySyncServerOptions] = CustomiseInterceptors( createOptions = (ci: CustomiseInterceptors[Id, NettySyncServerOptions]) => default(ci.interceptors) ).serverLog(defaultServerLog) - } private val log = LoggerFactory.getLogger(getClass.getName) - lazy val defaultServerLog: ServerLog[Id] = { + lazy val defaultServerLog: ServerLog[Id] = DefaultServerLog[Id]( doLogWhenReceived = debugLog(_, None), doLogWhenHandled = debugLog, @@ -50,7 +49,5 @@ object NettySyncServerOptions { doLogExceptions = (msg: String, ex: Throwable) => log.error(msg, ex), noLog = () ) - } private def debugLog(msg: String, exOpt: Option[Throwable]): Unit = NettyDefaults.debugLog(log, msg, exOpt) -} diff --git a/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/NettySyncRequestBody.scala b/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/internal/NettySyncRequestBody.scala similarity index 73% rename from server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/NettySyncRequestBody.scala rename to server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/internal/NettySyncRequestBody.scala index 4817bf175b..1f5be2eb34 100644 --- a/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/NettySyncRequestBody.scala +++ b/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/internal/NettySyncRequestBody.scala @@ -1,4 +1,4 @@ -package sttp.tapir.server.netty.loom +package sttp.tapir.server.netty.loom.internal import io.netty.handler.codec.http.HttpContent import org.playframework.netty.http.StreamedHttpRequest @@ -6,25 +6,23 @@ import org.reactivestreams.Publisher import sttp.capabilities import sttp.monad.MonadError import sttp.tapir.TapirFile -import sttp.tapir.capabilities.NoStreams import sttp.tapir.model.ServerRequest import sttp.tapir.server.netty.internal.NettyRequestBody import sttp.tapir.server.netty.internal.reactivestreams.{FileWriterSubscriber, SimpleSubscriber} +import sttp.tapir.server.netty.loom.* -private[netty] class NettySyncRequestBody(val createFile: ServerRequest => TapirFile) extends NettyRequestBody[Id, NoStreams] { +private[loom] class NettySyncRequestBody(val createFile: ServerRequest => TapirFile) extends NettyRequestBody[Id, OxStreams]: - override implicit val monad: MonadError[Id] = idMonad - override val streams: capabilities.Streams[NoStreams] = NoStreams + override given monad: MonadError[Id] = idMonad + override val streams: capabilities.Streams[OxStreams] = OxStreams override def publisherToBytes(publisher: Publisher[HttpContent], contentLength: Option[Long], maxBytes: Option[Long]): Array[Byte] = SimpleSubscriber.processAllBlocking(publisher, contentLength, maxBytes) override def writeToFile(serverRequest: ServerRequest, file: TapirFile, maxBytes: Option[Long]): Unit = - serverRequest.underlying match { + serverRequest.underlying match case r: StreamedHttpRequest => FileWriterSubscriber.processAllBlocking(r, file.toPath, maxBytes) case _ => () // Empty request - } override def toStream(serverRequest: ServerRequest, maxBytes: Option[Long]): streams.BinaryStream = throw new UnsupportedOperationException() -} diff --git a/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/internal/NettySyncToResponseBody.scala b/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/internal/NettySyncToResponseBody.scala new file mode 100644 index 0000000000..cf7ef6b12f --- /dev/null +++ b/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/internal/NettySyncToResponseBody.scala @@ -0,0 +1,47 @@ +package sttp.tapir.server.netty.loom.internal + +import _root_.ox.* +import io.netty.channel.ChannelHandlerContext +import sttp.capabilities +import sttp.model.HasHeaders +import sttp.monad.MonadError +import sttp.tapir.server.interpreter.ToResponseBody +import sttp.tapir.server.netty.NettyResponse +import sttp.tapir.server.netty.NettyResponseContent.ReactiveWebSocketProcessorNettyResponseContent +import sttp.tapir.server.netty.internal.{NettyToResponseBody, RunAsync} +import sttp.tapir.server.netty.loom._ +import sttp.tapir.server.netty.loom.internal.ox.OxDispatcher +import sttp.tapir.* + +import java.nio.charset.Charset + +private[loom] class NettySyncToResponseBody(runAsync: RunAsync[Id], oxDispatcher: OxDispatcher)(using me: MonadError[Id]) + extends ToResponseBody[NettyResponse, OxStreams]: + + val delegate = new NettyToResponseBody(runAsync)(me) + + override val streams: capabilities.Streams[OxStreams] = OxStreams + + def fromRawValue[R](v: R, headers: HasHeaders, format: CodecFormat, bodyType: RawBodyType[R]): NettyResponse = + delegate.fromRawValue(v, headers, format, bodyType) + def fromStreamValue(v: streams.BinaryStream, headers: HasHeaders, format: CodecFormat, charset: Option[Charset]): NettyResponse = + throw new UnsupportedOperationException + + override def fromWebSocketPipe[REQ, RESP]( + pipe: streams.Pipe[REQ, RESP], + o: WebSocketBodyOutput[streams.Pipe[REQ, RESP], REQ, RESP, ?, OxStreams] + ): NettyResponse = (ctx: ChannelHandlerContext) => + val channelPromise = ctx.newPromise() + new ReactiveWebSocketProcessorNettyResponseContent( + channelPromise, + ws.OxSourceWebSocketProcessor[REQ, RESP]( + oxDispatcher, + pipe.asInstanceOf[OxStreams.Pipe[REQ, RESP]], + o.asInstanceOf[WebSocketBodyOutput[OxStreams.Pipe[REQ, RESP], REQ, RESP, ?, OxStreams]], + ctx + ), + ignorePong = o.ignorePong, + autoPongOnPing = o.autoPongOnPing, + decodeCloseRequests = o.decodeCloseRequests, + autoPing = o.autoPing + ) diff --git a/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/internal/ox/OxDispatcher.scala b/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/internal/ox/OxDispatcher.scala new file mode 100644 index 0000000000..52db6ebcd6 --- /dev/null +++ b/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/internal/ox/OxDispatcher.scala @@ -0,0 +1,31 @@ +package sttp.tapir.server.netty.loom.internal.ox + +import ox.* +import ox.channels.Actor + +/** A dispatcher that can start arbitrary forks. Useful when one needs to start an asynchronous task from a thread outside of an Ox scope. + * Normally Ox doesn't allow to start forks from other threads, for example in callbacks of external libraries. If you create an + * OxDispatcher inside a scope and pass it for potential handling on another thread, that thread can call + * {{{ + * dispatcher.runAsync { + * // code to be executed in a fork + * } { throwable => + * // error handling if the fork fails with an exception, this will be run on the Ox virtual thread as well + * } + * }}} + * WARNING! Dispatchers should only be used in special cases, where the proper structure of concurrency scopes cannot be preserved. One + * such example is integration with callback-based systems like Netty, which runs handler methods on its event loop thread. + * @param ox + * concurrency scope where a fork will be run, using a nested scope to isolate failures. + */ +private[loom] class OxDispatcher()(using ox: Ox): + private class Runner: + def runAsync(thunk: Ox ?=> Unit, onError: Throwable => Unit): Unit = + fork { + try supervised(thunk) + catch case e => onError(e) + }.discard + + private val actor = Actor.create(new Runner) + + def runAsync(thunk: Ox ?=> Unit)(onError: Throwable => Unit): Unit = actor.tell(_.runAsync(thunk, onError)) diff --git a/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/internal/reactivestreams/ChannelSubscription.scala b/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/internal/reactivestreams/ChannelSubscription.scala new file mode 100644 index 0000000000..6e4779b3dc --- /dev/null +++ b/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/internal/reactivestreams/ChannelSubscription.scala @@ -0,0 +1,37 @@ +package sttp.tapir.server.netty.loom.internal.reactivestreams + +import org.reactivestreams.{Subscriber, Subscription} +import ox.* +import ox.channels.* + +/** Can be used together with an [[OxProcessor]] to read from a Source when there's demand. */ +private[loom] class ChannelSubscription[A]( + subscriber: Subscriber[? >: A], + source: Source[A] +) extends Subscription: + private val demands: Channel[Long] = Channel.unlimited[Long] + + def runBlocking(): Unit = + demands.foreach { demand => + var i = 0L + while (i < demand) + source.receiveOrClosed() match + case ChannelClosed.Done => + demands.doneOrClosed().discard + i = demand // break early + subscriber.onComplete() + case ChannelClosed.Error(e) => + demands.doneOrClosed().discard + i = demand + subscriber.onError(e) + case elem: A @unchecked => + i = i + 1 + subscriber.onNext(elem) + } + + override def cancel(): Unit = + demands.doneOrClosed().discard + + override def request(n: Long): Unit = + if n <= 0 then subscriber.onError(new IllegalArgumentException("ยง3.9: n must be greater than 0")) + else demands.send(n) diff --git a/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/internal/reactivestreams/OxProcessor.scala b/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/internal/reactivestreams/OxProcessor.scala new file mode 100644 index 0000000000..68d8ca1de0 --- /dev/null +++ b/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/internal/reactivestreams/OxProcessor.scala @@ -0,0 +1,79 @@ +package sttp.tapir.server.netty.loom.internal.reactivestreams + +import org.reactivestreams.Subscriber +import ox.* +import ox.channels.* +import org.reactivestreams.Subscription +import org.reactivestreams.Processor +import sttp.tapir.server.netty.loom.internal.ox.OxDispatcher +import sttp.tapir.server.netty.loom.OxStreams + +/** A reactive Processor, which is both a Publisher and a Subscriber + * + * @param oxDispatcher + * a dispatcher to which async tasks can be submitted (reading from a channel) + * @param pipeline + * user-defined processing pipeline expressed as an Ox Source => Source transformation + * @param wrapSubscriber + * an optional function allowing wrapping external subscribers, can be used to intercept onNext, onComplete and onError with custom + * handling. Can be just identity. + */ +private[loom] class OxProcessor[A, B]( + oxDispatcher: OxDispatcher, + pipeline: OxStreams.Pipe[A, B], + wrapSubscriber: Subscriber[? >: B] => Subscriber[? >: B] +) extends Processor[A, B]: + // Incoming requests are read from this subscription into an Ox Channel[A] + @volatile private var requestsSubscription: Subscription = _ + // An internal channel for holding incoming requests (`A`), will be wrapped with user's pipeline to produce responses (`B`) + private val channel = Channel.buffered[A](1) + + override def onError(reason: Throwable): Unit = + // As per rule 2.13, we need to throw a `java.lang.NullPointerException` if the `Throwable` is `null` + if reason == null then throw null + channel.errorOrClosed(reason).discard + + override def onNext(a: A): Unit = + if a == null then throw new NullPointerException("Element cannot be null") // Rule 2.13 + else + channel.sendOrClosed(a) match + case () => () + case _: ChannelClosed => + cancelSubscription() + onError(new IllegalStateException("onNext called when the channel is closed")) + + override def onSubscribe(s: Subscription): Unit = + if s == null then throw new NullPointerException("Subscription cannot be null") + else if requestsSubscription != null then s.cancel() // Rule 2.5: if onSubscribe is called twice, must cancel the second subscription + else + requestsSubscription = s + s.request(1) + + override def onComplete(): Unit = + channel.doneOrClosed().discard + + override def subscribe(subscriber: Subscriber[? >: B]): Unit = + if subscriber == null then throw new NullPointerException("Subscriber cannot be null") + val wrappedSubscriber = wrapSubscriber(subscriber) + oxDispatcher.runAsync { + val outgoingResponses: Source[B] = pipeline((channel: Source[A]).mapAsView { e => + requestsSubscription.request(1) + e + }) + val channelSubscription = new ChannelSubscription(wrappedSubscriber, outgoingResponses) + subscriber.onSubscribe(channelSubscription) + channelSubscription.runBlocking() // run the main loop which reads from the channel if there's demand + } { error => + wrappedSubscriber.onError(error) + onError(error) + } + + private def cancelSubscription() = + if requestsSubscription != null then + try requestsSubscription.cancel() + catch + case t: Throwable => + throw new IllegalStateException( + s"$requestsSubscription violated the Reactive Streams rule 3.15 by throwing an exception from cancel.", + t + ) diff --git a/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/internal/ws/OxSourceWebSocketProcessor.scala b/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/internal/ws/OxSourceWebSocketProcessor.scala new file mode 100644 index 0000000000..044d91a9ae --- /dev/null +++ b/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/internal/ws/OxSourceWebSocketProcessor.scala @@ -0,0 +1,81 @@ +package sttp.tapir.server.netty.loom.internal.ws + +import io.netty.channel.ChannelHandlerContext +import io.netty.handler.codec.http.websocketx.{CloseWebSocketFrame, WebSocketCloseStatus, WebSocketFrame => NettyWebSocketFrame} +import org.reactivestreams.{Processor, Subscriber, Subscription} +import org.slf4j.LoggerFactory +import ox.* +import ox.channels.{ChannelClosedException, Source} +import sttp.tapir.model.WebSocketFrameDecodeFailure +import sttp.tapir.server.netty.internal.ws.WebSocketFrameConverters._ +import sttp.tapir.server.netty.loom.OxStreams +import sttp.tapir.server.netty.loom.internal.ox.OxDispatcher +import sttp.tapir.server.netty.loom.internal.reactivestreams.OxProcessor +import sttp.tapir.{DecodeResult, WebSocketBodyOutput} +import sttp.ws.WebSocketFrame +import java.io.IOException + +private[loom] object OxSourceWebSocketProcessor: + + def apply[REQ, RESP]( + oxDispatcher: OxDispatcher, + pipe: OxStreams.Pipe[REQ, RESP], + o: WebSocketBodyOutput[OxStreams.Pipe[REQ, RESP], REQ, RESP, ?, OxStreams], + ctx: ChannelHandlerContext + ): Processor[NettyWebSocketFrame, NettyWebSocketFrame] = + val frame2FramePipe: OxStreams.Pipe[NettyWebSocketFrame, NettyWebSocketFrame] = + (source: Source[NettyWebSocketFrame]) => { + pipe( + optionallyConcatenateFrames( + source + .mapAsView { f => + val sttpFrame = nettyFrameToFrame(f) + f.release() + sttpFrame + }, + o.concatenateFragmentedFrames + ) + .mapAsView(f => + o.requests.decode(f) match { + case failure: DecodeResult.Failure => throw new WebSocketFrameDecodeFailure(f, failure) + case x: DecodeResult.Value[REQ] @unchecked => x.v + } + ) + ) + .mapAsView(r => frameToNettyFrame(o.responses.encode(r))) + } + // We need this kind of interceptor to make Netty reply correctly to closed channel or error + def wrapSubscriberWithNettyCallback[B](sub: Subscriber[? >: B]): Subscriber[? >: B] = new Subscriber[B] { + private val logger = LoggerFactory.getLogger(getClass.getName) + override def onSubscribe(s: Subscription): Unit = sub.onSubscribe(s) + override def onNext(t: B): Unit = sub.onNext(t) + override def onError(t: Throwable): Unit = + t match + case ChannelClosedException.Error(e: IOException) => + // Connection reset? + logger.info("Web Socket channel closed abnormally", e) + case e => + logger.error("Web Socket channel closed abnormally", e) + val _ = ctx.writeAndFlush(new CloseWebSocketFrame(WebSocketCloseStatus.INTERNAL_SERVER_ERROR, "Internal Server Error")) + sub.onError(t) + override def onComplete(): Unit = + val _ = ctx.writeAndFlush(new CloseWebSocketFrame(WebSocketCloseStatus.NORMAL_CLOSURE, "Bye")) + sub.onComplete() + } + new OxProcessor(oxDispatcher, frame2FramePipe, wrapSubscriberWithNettyCallback) + + private def optionallyConcatenateFrames(s: Source[WebSocketFrame], doConcatenate: Boolean)(using Ox): Source[WebSocketFrame] = + if doConcatenate then + type Accumulator = Option[Either[Array[Byte], String]] + s.mapStateful(() => None: Accumulator) { + case (None, f: WebSocketFrame.Ping) => (None, Some(f)) + case (None, f: WebSocketFrame.Pong) => (None, Some(f)) + case (None, f: WebSocketFrame.Close) => (None, Some(f)) + case (None, f: WebSocketFrame.Data[_]) if f.finalFragment => (None, Some(f)) + case (Some(Left(acc)), f: WebSocketFrame.Binary) if f.finalFragment => (None, Some(f.copy(payload = acc ++ f.payload))) + case (Some(Left(acc)), f: WebSocketFrame.Binary) if !f.finalFragment => (Some(Left(acc ++ f.payload)), None) + case (Some(Right(acc)), f: WebSocketFrame.Text) if f.finalFragment => (None, Some(f.copy(payload = acc + f.payload))) + case (Some(Right(acc)), f: WebSocketFrame.Text) if !f.finalFragment => (Some(Right(acc + f.payload)), None) + case (acc, f) => throw new IllegalStateException(s"Cannot accumulate web socket frames. Accumulator: $acc, frame: $f.") + }.collectAsView { case Some(f: WebSocketFrame) => f } + else s diff --git a/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/loom.scala b/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/loom.scala index 82c2c015cc..f5d3e8bc3d 100644 --- a/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/loom.scala +++ b/server/netty-server/loom/src/main/scala/sttp/tapir/server/netty/loom/loom.scala @@ -2,7 +2,7 @@ package sttp.tapir.server.netty import sttp.monad.MonadError -package object loom { +package object loom: type Id[X] = X type IdRoute = Route[Id] @@ -17,4 +17,3 @@ package object loom { try f finally e } -} diff --git a/server/netty-server/loom/src/test/scala/sttp/tapir/server/netty/loom/NettySyncServerTest.scala b/server/netty-server/loom/src/test/scala/sttp/tapir/server/netty/loom/NettySyncServerTest.scala index 08761f8f95..621d1c5857 100644 --- a/server/netty-server/loom/src/test/scala/sttp/tapir/server/netty/loom/NettySyncServerTest.scala +++ b/server/netty-server/loom/src/test/scala/sttp/tapir/server/netty/loom/NettySyncServerTest.scala @@ -1,35 +1,141 @@ package sttp.tapir.server.netty.loom -import cats.effect.{IO, Resource} +import cats.data.NonEmptyList +import cats.effect.unsafe.implicits.global +import cats.effect.IO import io.netty.channel.nio.NioEventLoopGroup -import org.scalatest.EitherValues -import sttp.tapir.server.netty.internal.FutureUtil.nettyFutureToScala -import sttp.tapir.server.tests._ -import sttp.tapir.tests.{Test, TestSuite} +import org.scalactic.source.Position +import org.scalatest.compatible.Assertion +import org.scalatest.funsuite.AsyncFunSuite +import org.scalatest.BeforeAndAfterAll +import org.slf4j.LoggerFactory +import ox.* +import ox.channels.Source +import sttp.capabilities.WebSockets +import sttp.capabilities.fs2.Fs2Streams +import sttp.client3.* +import sttp.model.* +import sttp.tapir.PublicEndpoint +import sttp.tapir.server.ServerEndpoint +import sttp.tapir.server.tests.* +import sttp.tapir.tests.* -import scala.concurrent.Future import scala.concurrent.duration.FiniteDuration +import scala.concurrent.Future + +class NettySyncServerTest extends AsyncFunSuite with BeforeAndAfterAll { + + val (backend, stopBackend) = backendResource.allocated.unsafeRunSync() + def testNameFilter: Option[String] = None // define to run a single test (temporarily for debugging) + { + val eventLoopGroup = new NioEventLoopGroup() + + val interpreter = new NettySyncTestServerInterpreter(eventLoopGroup) + val createServerTest = new NettySyncCreateServerTest(backend, interpreter) + val sleeper: Sleeper[Id] = (duration: FiniteDuration) => Thread.sleep(duration.toMillis) + + val tests = + new AllServerTests(createServerTest, interpreter, backend, staticContent = false, multipart = false) + .tests() ++ + new ServerGracefulShutdownTests(createServerTest, sleeper).tests() ++ + new ServerWebSocketTests(createServerTest, OxStreams, autoPing = true, failingPipe = true, handlePong = true) { + override def functionToPipe[A, B](f: A => B): OxStreams.Pipe[A, B] = ox ?=> in => in.map(f) + override def emptyPipe[A, B]: OxStreams.Pipe[A, B] = _ => Source.empty + }.tests() + + tests.foreach { t => + if (testNameFilter.forall(filter => t.name.contains(filter))) { + implicit val pos: Position = t.pos + + this.test(t.name)(t.f()) + } + } + } + override protected def afterAll(): Unit = { + stopBackend.unsafeRunSync() + super.afterAll() + } +} + +class NettySyncCreateServerTest( + backend: SttpBackend[IO, Fs2Streams[IO] & WebSockets], + interpreter: NettySyncTestServerInterpreter +) extends CreateServerTest[Id, OxStreams & WebSockets, NettySyncServerOptions, IdRoute] { + + private val logger = LoggerFactory.getLogger(getClass.getName) + + override def testServer[I, E, O]( + e: PublicEndpoint[I, E, O, OxStreams & WebSockets], + testNameSuffix: String = "", + interceptors: Interceptors = identity + )( + fn: I => Id[Either[E, O]] + )(runTest: (SttpBackend[IO, Fs2Streams[IO] & WebSockets], Uri) => IO[Assertion]): Test = { + testServerLogic(e.serverLogic(fn), testNameSuffix, interceptors)(runTest) + } + + override def testServerLogic( + e: ServerEndpoint[OxStreams & WebSockets, Id], + testNameSuffix: String = "", + interceptors: Interceptors = identity + )( + runTest: (SttpBackend[IO, Fs2Streams[IO] & WebSockets], Uri) => IO[Assertion] + ): Test = { + testServerLogicWithStop(e, testNameSuffix, interceptors)((_: IO[Unit]) => runTest) + } + + override def testServerLogicWithStop( + e: ServerEndpoint[OxStreams & WebSockets, Id], + testNameSuffix: String = "", + interceptors: Interceptors = identity, + gracefulShutdownTimeout: Option[FiniteDuration] = None + )( + runTest: IO[Unit] => (SttpBackend[IO, Fs2Streams[IO] & WebSockets], Uri) => IO[Assertion] + ): Test = { + Test( + e.showDetail + (if (testNameSuffix == "") "" else " " + testNameSuffix) + ) { + supervised { + val binding = interpreter.scopedServerWithInterceptorsStop(e, interceptors, gracefulShutdownTimeout) + val assertion: Assertion = + runTest(IO.blocking(binding.stop()))(backend, uri"http://localhost:${binding.port}") + .guarantee(IO(logger.info(s"Test completed on port ${binding.port}"))) + .unsafeRunSync() + Future.successful(assertion) + } + } + } + + override def testServerWithStop(name: String, rs: => NonEmptyList[IdRoute], gracefulShutdownTimeout: Option[FiniteDuration])( + runTest: IO[Unit] => (SttpBackend[IO, Fs2Streams[IO] & WebSockets], Uri) => IO[Assertion] + ): Test = throw new UnsupportedOperationException + + override def testServer(name: String, rs: => NonEmptyList[IdRoute])( + runTest: (SttpBackend[IO, Fs2Streams[IO] & WebSockets], Uri) => IO[Assertion] + ): Test = + Test(name) { + supervised { + val binding = interpreter.scopedServerWithRoutesStop(rs) + val assertion: Assertion = + runTest(backend, uri"http://localhost:${binding.port}") + .guarantee(IO(logger.info(s"Test completed on port ${binding.port}"))) + .unsafeRunSync() + Future.successful(assertion) + } + } -class NettySyncServerTest extends TestSuite with EitherValues { - override def tests: Resource[IO, List[Test]] = - backendResource.flatMap { backend => - Resource - .make(IO.delay { - val eventLoopGroup = new NioEventLoopGroup() - - val interpreter = new NettySyncTestServerInterpreter(eventLoopGroup) - val createServerTest = new DefaultCreateServerTest(backend, interpreter) - val sleeper: Sleeper[Id] = (duration: FiniteDuration) => Thread.sleep(duration.toMillis) - - val tests = - new AllServerTests(createServerTest, interpreter, backend, staticContent = false, multipart = false) - .tests() ++ - new ServerGracefulShutdownTests(createServerTest, sleeper).tests() - - (tests, eventLoopGroup) - }) { case (_, eventLoopGroup) => - IO.fromFuture(IO.delay(nettyFutureToScala(eventLoopGroup.shutdownGracefully()): Future[_])).void - } - .map { case (tests, _) => tests } + def testServer(name: String, es: NonEmptyList[ServerEndpoint[OxStreams & WebSockets, Id]])( + runTest: (SttpBackend[IO, Fs2Streams[IO] & WebSockets], Uri) => IO[Assertion] + ): Test = { + Test(name) { + supervised { + val binding = interpreter.scopedServerWithStop(es) + val assertion: Assertion = + runTest(backend, uri"http://localhost:${binding.port}") + .guarantee(IO(logger.info(s"Test completed on port ${binding.port}"))) + .unsafeRunSync() + Future.successful(assertion) + } } + } } diff --git a/server/netty-server/loom/src/test/scala/sttp/tapir/server/netty/loom/NettySyncTestServerInterpreter.scala b/server/netty-server/loom/src/test/scala/sttp/tapir/server/netty/loom/NettySyncTestServerInterpreter.scala index a5bd2159e2..68219d843a 100644 --- a/server/netty-server/loom/src/test/scala/sttp/tapir/server/netty/loom/NettySyncTestServerInterpreter.scala +++ b/server/netty-server/loom/src/test/scala/sttp/tapir/server/netty/loom/NettySyncTestServerInterpreter.scala @@ -3,18 +3,30 @@ package sttp.tapir.server.netty.loom import cats.data.NonEmptyList import cats.effect.{IO, Resource} import io.netty.channel.nio.NioEventLoopGroup +import internal.ox.OxDispatcher +import ox.* import sttp.tapir.server.ServerEndpoint import sttp.tapir.server.netty.NettyConfig import sttp.tapir.server.tests.TestServerInterpreter import sttp.tapir.tests.Port import scala.concurrent.duration.FiniteDuration +import sttp.capabilities.WebSockets class NettySyncTestServerInterpreter(eventLoopGroup: NioEventLoopGroup) - extends TestServerInterpreter[Id, Any, NettySyncServerOptions, IdRoute] { - override def route(es: List[ServerEndpoint[Any, Id]], interceptors: Interceptors): IdRoute = { + extends TestServerInterpreter[Id, OxStreams with WebSockets, NettySyncServerOptions, IdRoute] { + override def route(es: List[ServerEndpoint[OxStreams with WebSockets, Id]], interceptors: Interceptors): IdRoute = { val serverOptions: NettySyncServerOptions = interceptors(NettySyncServerOptions.customiseInterceptors).options - NettySyncServerInterpreter(serverOptions).toRoute(es) + supervised { // not a correct way, but this method is only used in a few tests which don't test anything related to scopes + NettySyncServerInterpreter(serverOptions).toRoute(es, new OxDispatcher()) + } + } + + def route(es: List[ServerEndpoint[OxStreams with WebSockets, Id]], interceptors: Interceptors)(using Ox): IdRoute = { + val serverOptions: NettySyncServerOptions = interceptors(NettySyncServerOptions.customiseInterceptors).options + supervised { // not a correct way, but this method is only used in a few tests which don't test anything related to scopes + NettySyncServerInterpreter(serverOptions).toRoute(es, new OxDispatcher()) + } } override def serverWithStop( @@ -25,9 +37,41 @@ class NettySyncTestServerInterpreter(eventLoopGroup: NioEventLoopGroup) NettyConfig.default.eventLoopGroup(eventLoopGroup).randomPort.withDontShutdownEventLoopGroupOnClose.noGracefulShutdown val customizedConfig = gracefulShutdownTimeout.map(config.withGracefulShutdownTimeout).getOrElse(config) val options = NettySyncServerOptions.default - val bind = IO.blocking(NettySyncServer(options, customizedConfig).addRoutes(routes.toList).start()) + val bind = IO.blocking(NettySyncServer(options, customizedConfig).start(routes.toList)) Resource .make(bind.map(b => (b.port, IO.blocking(b.stop())))) { case (_, stop) => stop } } + + def scopedServerWithRoutesStop( + routes: NonEmptyList[IdRoute], + gracefulShutdownTimeout: Option[FiniteDuration] = None + )(using Ox): NettySyncServerBinding = + val config = + NettyConfig.default.eventLoopGroup(eventLoopGroup).randomPort.withDontShutdownEventLoopGroupOnClose.noGracefulShutdown + val customizedConfig = gracefulShutdownTimeout.map(config.withGracefulShutdownTimeout).getOrElse(config) + val options = NettySyncServerOptions.default + useInScope(NettySyncServer(options, customizedConfig).start(routes.toList))(_.stop()) + + def scopedServerWithInterceptorsStop( + endpoint: ServerEndpoint[OxStreams with WebSockets, Id], + interceptors: Interceptors = identity, + gracefulShutdownTimeout: Option[FiniteDuration] = None + )(using Ox): NettySyncServerBinding = + val config = + NettyConfig.default.eventLoopGroup(eventLoopGroup).randomPort.withDontShutdownEventLoopGroupOnClose.noGracefulShutdown + val customizedConfig = gracefulShutdownTimeout.map(config.withGracefulShutdownTimeout).getOrElse(config) + val options = interceptors(NettySyncServerOptions.customiseInterceptors).options + useInScope(NettySyncServer(customizedConfig).addEndpoint(endpoint, options).start())(_.stop()) + + def scopedServerWithStop( + endpoints: NonEmptyList[ServerEndpoint[OxStreams with WebSockets, Id]], + gracefulShutdownTimeout: Option[FiniteDuration] = None + )(using Ox): NettySyncServerBinding = + val config = + NettyConfig.default.eventLoopGroup(eventLoopGroup).randomPort.withDontShutdownEventLoopGroupOnClose.noGracefulShutdown + val customizedConfig = gracefulShutdownTimeout.map(config.withGracefulShutdownTimeout).getOrElse(config) + val options = NettySyncServerOptions.default + val interpreter = NettySyncServerInterpreter(options) + useInScope(NettySyncServer(options, customizedConfig).addEndpoints(endpoints.toList).start())(_.stop()) } diff --git a/server/netty-server/loom/src/test/scala/sttp/tapir/server/netty/loom/perf/NettySyncServerRunner.scala b/server/netty-server/loom/src/test/scala/sttp/tapir/server/netty/loom/perf/NettySyncServerRunner.scala new file mode 100644 index 0000000000..6455d733b0 --- /dev/null +++ b/server/netty-server/loom/src/test/scala/sttp/tapir/server/netty/loom/perf/NettySyncServerRunner.scala @@ -0,0 +1,98 @@ +package sttp.tapir.server.netty.loom.perf + +import ox.* +import ox.channels.* +import sttp.tapir.server.netty.loom.NettySyncServerOptions +import sttp.tapir.server.netty.loom.NettySyncServerBinding +import sttp.tapir.server.netty.loom.NettySyncServer + +import sttp.tapir.* +import sttp.tapir.server.netty.loom.Id +import sttp.tapir.server.ServerEndpoint +import sttp.tapir.server.model.EndpointExtensions.* +import sttp.tapir.server.netty.loom.OxStreams +import sttp.tapir.Endpoint +import sttp.capabilities.WebSockets +import scala.concurrent.duration._ + +object NettySyncServerRunner { + val LargeInputSize = 5 * 1024L * 1024L + val WebSocketSingleResponseLag = 100.millis + + type EndpointGen = Int => PublicEndpoint[_, String, String, Any] + type ServerEndpointGen[F[_]] = Int => ServerEndpoint[Any, F] + def serverEndpoints[F[_]](reply: String => F[String]): List[ServerEndpointGen[F]] = { + List( + { (n: Int) => + endpoint.get + .in("path" + n.toString) + .in(path[Int]("id")) + .out(stringBody) + .serverLogicSuccess { id => + reply((id + n).toString) + } + }, + { (n: Int) => + endpoint.post + .in("path" + n.toString) + .in(stringBody) + .maxRequestBodyLength(LargeInputSize + 1024L) + .out(stringBody) + .serverLogicSuccess { (body: String) => + reply(s"Ok [$n], string length = ${body.length}") + } + }, + { (n: Int) => + endpoint.post + .in("pathBytes" + n.toString) + .in(byteArrayBody) + .maxRequestBodyLength(LargeInputSize + 1024L) + .out(stringBody) + .serverLogicSuccess { (body: Array[Byte]) => + reply(s"Ok [$n], bytes length = ${body.length}") + } + } + ) + } + + val wsBaseEndpoint = endpoint.get.in("ws" / "ts") + + val wsPipe: OxStreams.Pipe[Long, Long] = { in => + fork { + in.drain() + } + Source.tick(WebSocketSingleResponseLag).map(_ => System.currentTimeMillis()) + } + + val wsEndpoint: Endpoint[Unit, Unit, Unit, OxStreams.Pipe[Long, Long], OxStreams with WebSockets] = wsBaseEndpoint + .out( + webSocketBody[Long, CodecFormat.TextPlain, Long, CodecFormat.TextPlain](OxStreams) + .concatenateFragmentedFrames(false) + .autoPongOnPing(false) + .ignorePong(true) + .autoPing(None) + ) + val wsServerEndpoint = wsEndpoint.serverLogicSuccess[Id](_ => wsPipe) + + val endpoints = genEndpointsId(1) + + def main(args: Array[String]): Unit = { + val declaredPort = 8080 + val declaredHost = "0.0.0.0" + + supervised { + val serverBinding: NettySyncServerBinding = useInScope( + NettySyncServer(NettySyncServerOptions.customiseInterceptors.options) + .port(declaredPort) + .host(declaredHost) + .addEndpoints(wsServerEndpoint :: endpoints) + .start() + )(_.stop()) + println(s"Netty running with binding: $serverBinding") + never + } + } + def genServerEndpoints[F[_]](routeCount: Int)(reply: String => F[String]): List[ServerEndpoint[Any, F]] = + serverEndpoints[F](reply).flatMap(gen => (0 to routeCount).map(i => gen(i))) + def genEndpointsId(count: Int): List[ServerEndpoint[Any, Id]] = genServerEndpoints[Id](count)(x => x: Id[String]) +}