Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WebSockets for netty-cats #3628

Merged
merged 31 commits into from
Apr 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
28455d2
Initial implementation
kciesielski Mar 20, 2024
5f61671
Add comments
kciesielski Mar 25, 2024
6b92093
Remove TODO
kciesielski Mar 25, 2024
ecc0f98
Fix cancellation and error handling
kciesielski Mar 25, 2024
48ed4f7
A few more tweaks to cancellation and error handling
kciesielski Mar 25, 2024
baa9820
Add volatiles where needed
kciesielski Mar 26, 2024
d9b9a69
Auto-ping
kciesielski Mar 26, 2024
4d1e11d
Adjust netty-zio
kciesielski Mar 26, 2024
f9bb21e
Don't decode Ping
kciesielski Mar 26, 2024
f420423
Remove cancelation test
kciesielski Mar 26, 2024
aabdf15
Adjust error handling test to http4s
kciesielski Mar 26, 2024
da2860f
Correctly read bytes from Netty Frames
kciesielski Mar 26, 2024
9eeb2ac
Parameterize tests
kciesielski Mar 27, 2024
b4e7192
Remove println
kciesielski Mar 27, 2024
83fa2f8
Fix memory leak
kciesielski Mar 27, 2024
4705519
Cleanup and minor tweaks
kciesielski Mar 27, 2024
929d708
Documentation
kciesielski Mar 27, 2024
8497790
Organize handlers
kciesielski Mar 27, 2024
2838ece
Handle hanshake for regular endpoints with 400
kciesielski Mar 27, 2024
e9a213f
Extract methods and reply 400 on regular endpoints
kciesielski Mar 27, 2024
ae1af27
Add capability to other methods
kciesielski Mar 27, 2024
f65b0f2
Explicitly close the channel
kciesielski Mar 28, 2024
76a978d
Fix typo
kciesielski Mar 28, 2024
d811daf
Review fixes
kciesielski Mar 28, 2024
8e2efc7
Move ws-specific stuff to its own package
kciesielski Mar 28, 2024
9cf4e42
Add handlers in a list
kciesielski Mar 28, 2024
b9db8dd
More review fixes
kciesielski Mar 28, 2024
ba7fb91
Use parasitic EC
kciesielski Mar 28, 2024
ae65609
Improvements after code review
kciesielski Mar 29, 2024
d2a75d0
Handle handshake in NettyServerHandler
kciesielski Mar 29, 2024
c196742
Remove unneeded handler name
kciesielski Mar 29, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -2196,6 +2196,7 @@ lazy val documentation: ProjectMatrix = (projectMatrix in file("generated-doc"))
sprayJson,
http4sClient,
http4sServerZio,
nettyServerCats,
sttpClient,
playClient,
sttpStubServer,
Expand Down
65 changes: 65 additions & 0 deletions doc/server/netty.md
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,71 @@ NettyFutureServer(NettyFutureServerOptions.customiseInterceptors.serverLog(None)
NettyFutureServer(NettyConfig.default.socketBacklog(256))
```

## Web sockets

The netty-cats interpreter supports web sockets, with pipes of type `fs2.Pipe[F, REQ, RESP]`. See [web sockets](../endpoint/websockets.md)
for more details.

To create a web socket endpoint, use Tapir's `out(webSocketBody)` output type:

```scala mdoc:compile-only
import cats.effect.kernel.Resource
import cats.effect.{IO, ResourceApp}
import cats.syntax.all._
import fs2.Pipe
import sttp.capabilities.fs2.Fs2Streams
import sttp.tapir._
import sttp.tapir.server.netty.cats.NettyCatsServer
import sttp.ws.WebSocketFrame

import scala.concurrent.duration._

object WebSocketsNettyCatsServer extends ResourceApp.Forever {

// Web socket endpoint
val wsEndpoint =
endpoint.get
.in("ws")
.out(
webSocketBody[String, CodecFormat.TextPlain, String, CodecFormat.TextPlain](Fs2Streams[IO])
.concatenateFragmentedFrames(false) // All these options are supported by tapir-netty
.ignorePong(true)
.autoPongOnPing(true)
.decodeCloseRequests(false)
.decodeCloseResponses(false)
.autoPing(Some((10.seconds, WebSocketFrame.Ping("ping-content".getBytes))))
)

// Your processor transforming a stream of requests into a stream of responses
val pipe: Pipe[IO, String, String] = requestStream => requestStream.evalMap(str => IO.pure(str.toUpperCase))
// Alternatively, requests can be ignored and the backend can be turned into a stream emitting frames to the client:
// val pipe: Pipe[IO, String, String] = requestStream => someDataEmittingStream.concurrently(requestStream.as(()))

val wsServerEndpoint = wsEndpoint.serverLogicSuccess(_ => IO.pure(pipe))

// A regular /GET endpoint
val helloWorldEndpoint: PublicEndpoint[String, Unit, String, Any] =
endpoint.get.in("hello").in(query[String]("name")).out(stringBody)

val helloWorldServerEndpoint = helloWorldEndpoint
.serverLogicSuccess(name => IO.pure(s"Hello, $name!"))

override def run(args: List[String]) = NettyCatsServer
.io()
.flatMap { server =>
Resource
.make(
server
.port(8080)
.host("localhost")
.addEndpoints(List(wsServerEndpoint, helloWorldServerEndpoint))
.start()
)(_.stop())
.as(())
}
}
```

## Graceful shutdown

A Netty server can be gracefully closed using the function `NettyFutureServerBinding.stop()` (and analogous functions available in Cats and ZIO bindings). This function ensures that the server will wait at most 10 seconds for in-flight requests to complete, while rejecting all new requests with 503 during this period. Afterwards, it closes all server resources.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,36 @@ package sttp.tapir.perf.netty.cats
import cats.effect.IO
import cats.effect.kernel.Resource
import cats.effect.std.Dispatcher
import fs2.Stream
import sttp.tapir.{CodecFormat, webSocketBody}
import sttp.tapir.perf.Common._
import sttp.tapir.perf.apis._
import sttp.tapir.server.ServerEndpoint
import sttp.tapir.server.netty.cats.NettyCatsServer
import sttp.tapir.server.netty.cats.NettyCatsServerOptions
import sttp.ws.WebSocketFrame
import sttp.capabilities.fs2.Fs2Streams

object Tapir extends Endpoints
import scala.concurrent.duration._

object NettyCats {
object Tapir extends Endpoints {
val wsResponseStream = Stream.fixedRate[IO](WebSocketSingleResponseLag, dampen = false)
val wsEndpoint = wsBaseEndpoint
.out(
webSocketBody[Long, CodecFormat.TextPlain, Long, CodecFormat.TextPlain](Fs2Streams[IO])
.concatenateFragmentedFrames(false)
.autoPongOnPing(false)
.ignorePong(true)
.autoPing(None)
)
}

object NettyCats {
val wsServerEndpoint = Tapir.wsEndpoint.serverLogicSuccess(_ =>
IO.pure { (in: Stream[IO, Long]) =>
Tapir.wsResponseStream.evalMap(_ => IO.realTime.map(_.toMillis)).concurrently(in.as(()))
}
)
def runServer(endpoints: List[ServerEndpoint[Any, IO]], withServerLog: Boolean = false): IO[ServerRunner.KillSwitch] = {
val declaredPort = Port
val declaredHost = "0.0.0.0"
Expand All @@ -25,7 +45,7 @@ object NettyCats {
server
.port(declaredPort)
.host(declaredHost)
.addEndpoints(endpoints)
.addEndpoints(wsServerEndpoint :: endpoints)
.start()
)(binding => binding.stop())
} yield ()).allocated.map(_._2)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,13 @@ class AkkaHttpServerTest extends TestSuite with EitherValues {

new AllServerTests(createServerTest, interpreter, backend).tests() ++
new ServerStreamingTests(createServerTest).tests(AkkaStreams)(drainAkka) ++
new ServerWebSocketTests(createServerTest, AkkaStreams) {
new ServerWebSocketTests(
createServerTest,
AkkaStreams,
autoPing = false,
failingPipe = true,
handlePong = false
) {
override def functionToPipe[A, B](f: A => B): streams.Pipe[A, B] = Flow.fromFunction(f)
override def emptyPipe[A, B]: Flow[A, B, Any] = Flow.fromSinkAndSource(Sink.ignore, Source.empty)
}.tests() ++
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import cats.effect._
import cats.effect.unsafe.implicits.global
import cats.syntax.all._
import fs2.Pipe
import fs2.Stream
import org.http4s.blaze.server.BlazeServerBuilder
import org.http4s.server.Router
import org.http4s.server.ContextMiddleware
Expand Down Expand Up @@ -138,7 +139,13 @@ class Http4sServerTest[R >: Fs2Streams[IO] with WebSockets] extends TestSuite wi

new AllServerTests(createServerTest, interpreter, backend).tests() ++
new ServerStreamingTests(createServerTest).tests(Fs2Streams[IO])(drainFs2) ++
new ServerWebSocketTests(createServerTest, Fs2Streams[IO]) {
new ServerWebSocketTests(
createServerTest,
Fs2Streams[IO],
autoPing = true,
failingPipe = true,
handlePong = false
) {
override def functionToPipe[A, B](f: A => B): streams.Pipe[A, B] = in => in.map(f)
override def emptyPipe[A, B]: Pipe[IO, A, B] = _ => fs2.Stream.empty
}.tests() ++
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,13 @@ class ZHttp4sServerTest extends TestSuite with OptionValues {

new AllServerTests(createServerTest, interpreter, backend).tests() ++
new ServerStreamingTests(createServerTest).tests(ZioStreams)(drainZStream) ++
new ServerWebSocketTests(createServerTest, ZioStreams) {
new ServerWebSocketTests(
createServerTest,
ZioStreams,
autoPing = true,
failingPipe = false,
handlePong = false
) {
override def functionToPipe[A, B](f: A => B): streams.Pipe[A, B] = in => in.map(f)
override def emptyPipe[A, B]: streams.Pipe[A, B] = _ => ZStream.empty
}.tests() ++
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package sttp.tapir.server.netty.cats.internal

import scala.concurrent.ExecutionContext

object ExecutionContexts {
val sameThread: ExecutionContext = new ExecutionContext {
override def execute(runnable: Runnable): Unit = runnable.run()

override def reportFailure(cause: Throwable): Unit =
ExecutionContext.defaultReporter(cause)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package sttp.tapir.server.netty.cats.internal

import scala.concurrent.ExecutionContext

object ExecutionContexts {
val sameThread: ExecutionContext = ExecutionContext.parasitic
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import io.netty.channel._
import io.netty.channel.group.{ChannelGroup, DefaultChannelGroup}
import io.netty.channel.unix.DomainSocketAddress
import io.netty.util.concurrent.DefaultEventExecutor
import sttp.capabilities.WebSockets
import sttp.capabilities.fs2.Fs2Streams
import sttp.monad.MonadError
import sttp.tapir.integ.cats.effect.CatsMonadError
Expand All @@ -25,13 +26,16 @@ import scala.concurrent.Future
import scala.concurrent.duration._

case class NettyCatsServer[F[_]: Async](routes: Vector[Route[F]], options: NettyCatsServerOptions[F], config: NettyConfig) {
def addEndpoint(se: ServerEndpoint[Fs2Streams[F], F]): NettyCatsServer[F] = addEndpoints(List(se))
def addEndpoint(se: ServerEndpoint[Fs2Streams[F], F], overrideOptions: NettyCatsServerOptions[F]): NettyCatsServer[F] =
def addEndpoint(se: ServerEndpoint[Fs2Streams[F] with WebSockets, F]): NettyCatsServer[F] = addEndpoints(List(se))
def addEndpoint(se: ServerEndpoint[Fs2Streams[F], F] with WebSockets, overrideOptions: NettyCatsServerOptions[F]): NettyCatsServer[F] =
addEndpoints(List(se), overrideOptions)
def addEndpoints(ses: List[ServerEndpoint[Fs2Streams[F], F]]): NettyCatsServer[F] = addRoute(
def addEndpoints(ses: List[ServerEndpoint[Fs2Streams[F] with WebSockets, F]]): NettyCatsServer[F] = addRoute(
kciesielski marked this conversation as resolved.
Show resolved Hide resolved
NettyCatsServerInterpreter(options).toRoute(ses)
)
def addEndpoints(ses: List[ServerEndpoint[Fs2Streams[F], F]], overrideOptions: NettyCatsServerOptions[F]): NettyCatsServer[F] = addRoute(
def addEndpoints(
ses: List[ServerEndpoint[Fs2Streams[F] with WebSockets, F]],
overrideOptions: NettyCatsServerOptions[F]
): NettyCatsServer[F] = addRoute(
NettyCatsServerInterpreter(overrideOptions).toRoute(ses)
)

Expand Down Expand Up @@ -74,7 +78,7 @@ case class NettyCatsServer[F[_]: Async](routes: Vector[Route[F]], options: Netty
val channelFuture =
NettyBootstrap(
config,
new NettyServerHandler(route, unsafeRunAsync, channelGroup, isShuttingDown, config.serverHeader),
new NettyServerHandler(route, unsafeRunAsync, channelGroup, isShuttingDown, config.serverHeader, config.isSsl),
eventLoopGroup,
socketOverride
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,13 @@ import sttp.tapir.server.interpreter.{BodyListener, FilterServerEndpoints, Serve
import sttp.tapir.server.netty.internal.{NettyBodyListener, RunAsync, _}
import sttp.tapir.server.netty.cats.internal.NettyCatsRequestBody
import sttp.tapir.server.netty.{NettyResponse, NettyServerRequest, Route}
import sttp.capabilities.WebSockets

trait NettyCatsServerInterpreter[F[_]] {
implicit def async: Async[F]
def nettyServerOptions: NettyCatsServerOptions[F]

def toRoute(ses: List[ServerEndpoint[Fs2Streams[F], F]]): Route[F] = {
def toRoute(ses: List[ServerEndpoint[Fs2Streams[F] with WebSockets, F]]): Route[F] = {

implicit val monad: MonadError[F] = new CatsMonadError[F]
val runAsync = new RunAsync[F] {
Expand All @@ -31,7 +32,7 @@ trait NettyCatsServerInterpreter[F[_]] {
val createFile = nettyServerOptions.createFile
val deleteFile = nettyServerOptions.deleteFile

val serverInterpreter = new ServerInterpreter[Fs2Streams[F], F, NettyResponse, Fs2Streams[F]](
val serverInterpreter = new ServerInterpreter[Fs2Streams[F] with WebSockets, F, NettyResponse, Fs2Streams[F]](
FilterServerEndpoints(ses),
new NettyCatsRequestBody(createFile, Fs2StreamCompatible[F](nettyServerOptions.dispatcher)),
new NettyToStreamsResponseBody(Fs2StreamCompatible[F](nettyServerOptions.dispatcher)),
Expand Down
Original file line number Diff line number Diff line change
@@ -1,22 +1,20 @@
package sttp.tapir.server.netty.cats.internal

import cats.effect.kernel.{Async, Sync}
import cats.effect.std.Dispatcher
import fs2.interop.reactivestreams.{StreamSubscriber, StreamUnicastPublisher}
import fs2.io.file.{Files, Flags, Path}
import fs2.{Chunk, Pipe}
import io.netty.buffer.Unpooled
import io.netty.channel.{ChannelFuture, ChannelHandlerContext}
import io.netty.handler.codec.http.websocketx._
import io.netty.handler.codec.http.{DefaultHttpContent, HttpContent}
import org.reactivestreams.Publisher
import sttp.tapir.FileRange
import org.reactivestreams.{Processor, Publisher}
import sttp.capabilities.fs2.Fs2Streams
import sttp.tapir.server.netty.internal._
import sttp.tapir.{FileRange, WebSocketBodyOutput}

import java.io.InputStream
import cats.effect.std.Dispatcher
import sttp.capabilities.fs2.Fs2Streams
import fs2.io.file.Path
import fs2.io.file.Files
import cats.effect.kernel.Async
import fs2.io.file.Flags
import fs2.interop.reactivestreams.StreamUnicastPublisher
import cats.effect.kernel.Sync
import fs2.Chunk
import fs2.interop.reactivestreams.StreamSubscriber

object Fs2StreamCompatible {

Expand Down Expand Up @@ -68,6 +66,26 @@ object Fs2StreamCompatible {
override def emptyStream: streams.BinaryStream =
fs2.Stream.empty

override def asWsProcessor[REQ, RESP](
pipe: Pipe[F, REQ, RESP],
o: WebSocketBodyOutput[Pipe[F, REQ, RESP], REQ, RESP, ?, Fs2Streams[F]],
ctx: ChannelHandlerContext
): Processor[WebSocketFrame, WebSocketFrame] = {
val wsCompletedPromise = ctx.newPromise()
wsCompletedPromise.addListener((f: ChannelFuture) => {
// A special callback that has to be used when a SteramSubscription cancels or fails.
// This can happen in case of errors in the pipeline which are not signalled correctly,
// like throwing exceptions directly.
// Without explicit Close frame a client may hang on waiting and not knowing about closed channel.
if (f.isCancelled) {
val _ = ctx.writeAndFlush(new CloseWebSocketFrame(WebSocketCloseStatus.NORMAL_CLOSURE, "Canceled"))
} else if (!f.isSuccess) {
val _ = ctx.writeAndFlush(new CloseWebSocketFrame(WebSocketCloseStatus.INTERNAL_SERVER_ERROR, "Error"))
}
})
new WebSocketPipeProcessor[F, REQ, RESP](pipe, dispatcher, o, wsCompletedPromise)
}

private def inputStreamToFs2(inputStream: () => InputStream, chunkSize: Int) =
fs2.io.readInputStream(
Sync[F].blocking(inputStream()),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import sttp.tapir.TapirFile
import sttp.tapir.integ.cats.effect.CatsMonadError
import sttp.tapir.model.ServerRequest
import sttp.tapir.server.netty.internal.{NettyStreamingRequestBody, StreamCompatible}
import sttp.capabilities.WebSockets

private[cats] class NettyCatsRequestBody[F[_]: Async](
val createFile: ServerRequest => F[TapirFile],
Expand All @@ -24,7 +25,8 @@ private[cats] class NettyCatsRequestBody[F[_]: Async](
streamCompatible.fromPublisher(publisher, maxBytes).compile.to(Chunk).map(_.toArray[Byte])

override def writeToFile(serverRequest: ServerRequest, file: TapirFile, maxBytes: Option[Long]): F[Unit] =
(toStream(serverRequest, maxBytes).asInstanceOf[streamCompatible.streams.BinaryStream])
(toStream(serverRequest, maxBytes)
.asInstanceOf[streamCompatible.streams.BinaryStream])
.through(
Files[F](Files.forAsync[F]).writeAll(Path.fromNioPath(file.toPath))
)
Expand Down
Loading
Loading