Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update ox, add WS chat example #4019

Merged
merged 3 commits into from
Sep 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions doc/server/netty.md
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,10 @@ object WebSocketsNettyCatsServer extends ResourceApp.Forever {
In the Loom-based backend, Tapir uses [Ox](https://ox.softwaremill.com) to manage concurrency, and your transformation pipeline should be represented as `Ox ?=> Source[A] => Source[B]`. Any forks started within this function will be run under a safely isolated internal scope.
See [examples/websocket/WebSocketNettySyncServer.scala](https://github.com/softwaremill/tapir/blob/master/examples/src/main/scala/sttp/tapir/examples/websocket/WebSocketNettySyncServer.scala) for a full example.

```{note}
The pipeline transform a source of incoming web socket messages (received from the client), into a source of outgoing web socket messages (which will be sent to the client), within some concurrency scope. Once the incoming source is done, the client has closed the connection. In that case, remember to close the outgoing source as well: otherwise the scope will leak and won't be closed. An error will be logged if the outgoing channel is not closed within a timeout after a close frame is received.
```

## Graceful shutdown

A Netty server can be gracefully closed using the function `NettyFutureServerBinding.stop()` (and analogous functions available in Cats and ZIO bindings). This function ensures that the server will wait at most 10 seconds for in-flight requests to complete, while rejecting all new requests with 503 during this period. Afterwards, it closes all server resources.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
// {cat=WebSocket; effects=Direct; server=Netty}: A WebSocket chat across multiple clients connected to the same server

//> using dep com.softwaremill.sttp.tapir::tapir-core:1.11.2
//> using dep com.softwaremill.sttp.tapir::tapir-netty-server-sync:1.11.2
//> using dep com.softwaremill.ox::core:0.3.7

package sttp.tapir.examples.websocket

import ox.channels.{Actor, ActorRef, Channel, ChannelClosed, Default, DefaultResult, selectOrClosed}
import ox.{ExitCode, IO, Ox, OxApp, fork, never, releaseAfterScope}
import sttp.tapir.*
import sttp.tapir.CodecFormat.*
import sttp.tapir.server.netty.sync.{NettySyncServer, OxStreams}

import java.util.UUID

type ChatMemberId = UUID

case class ChatMember(id: ChatMemberId, channel: Channel[Message])
object ChatMember:
def create: ChatMember = ChatMember(UUID.randomUUID(), Channel.bufferedDefault[Message])

class ChatRoom:
private var members: Map[ChatMemberId, ChatMember] = Map()

def connected(m: ChatMember): Unit =
members = members + (m.id -> m)
println(s"Connected: ${m.id}, number of members: ${members.size}")

def disconnected(m: ChatMember): Unit =
members = members - m.id
println(s"Disconnected: ${m.id}, number of members: ${members.size}")

def incoming(message: Message): Unit =
println(s"Broadcasting: ${message.v}")
members = members.flatMap { (id, member) =>
selectOrClosed(member.channel.sendClause(message), Default(())) match
case member.channel.Sent() => Some((id, member))
case _: ChannelClosed =>
println(s"Channel of member $id closed, removing from members")
None
case DefaultResult(_) =>
println(s"Buffer for member $id full, not sending message")
Some((id, member))
}

//

case class Message(v: String) // could be more complex, e.g. JSON including nickname + message
given Codec[String, Message, TextPlain] = Codec.string.map(Message(_))(_.v)

val chatEndpoint = endpoint.get
.in("chat")
.out(webSocketBody[Message, TextPlain, Message, TextPlain](OxStreams))

def chatProcessor(a: ActorRef[ChatRoom]): OxStreams.Pipe[Message, Message] =
incoming => {
val member = ChatMember.create

a.tell(_.connected(member))

fork {
incoming.foreach { msg =>
a.tell(_.incoming(msg))
}
// all incoming messages are processed (= client closed), completing the outgoing channel as well
member.channel.done()
}

// however the scope ends (client close or error), we need to notify the chat room
releaseAfterScope {
a.tell(_.disconnected(member))
}

member.channel
}

object WebSocketChatNettySyncServer extends OxApp:
override def run(args: Vector[String])(using Ox, IO): ExitCode =
val chatActor = Actor.create(new ChatRoom)
val chatServerEndpoint = chatEndpoint.handleSuccess(_ => chatProcessor(chatActor))
val binding = NettySyncServer().addEndpoint(chatServerEndpoint).start()
releaseAfterScope(binding.stop())
never
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,11 @@ import ox.channels.*
import sttp.capabilities.WebSockets
import sttp.tapir.*
import sttp.tapir.server.netty.sync.OxStreams
import sttp.tapir.server.netty.sync.OxStreams.Pipe // alias for Ox ?=> Source[A] => Source[B]
import sttp.tapir.server.netty.sync.OxStreams.Pipe
import sttp.tapir.server.netty.sync.NettySyncServer
import sttp.ws.WebSocketFrame

import java.util.concurrent.atomic.AtomicBoolean
import scala.concurrent.duration.*

object WebSocketNettySyncServer:
Expand All @@ -33,17 +34,21 @@ object WebSocketNettySyncServer:

// Your processor transforming a stream of requests into a stream of responses
val wsPipe: Pipe[String, String] = requestStream => requestStream.map(_.toUpperCase)
// Alternatively, requests and responses can be treated separately, for example to emit frames to the client from another source:

// Alternative logic (not used here): requests and responses can be treated separately, for example to emit frames
// to the client from another source.
val wsPipe2: Pipe[String, String] = { in =>
val running = new AtomicBoolean(true) // TODO use https://github.com/softwaremill/ox/issues/209 once available
fork {
in.drain() // read and ignore requests
running.set(false) // stopping the responses
}
// emit periodic responses
Source.tick(1.second).map(_ => System.currentTimeMillis()).map(_.toString)
Source.tick(1.second).takeWhile(_ => running.get()).map(_ => System.currentTimeMillis()).map(_.toString)
}

// The WebSocket endpoint, builds the pipeline in serverLogicSuccess
val wsServerEndpoint = wsEndpoint.handleSuccess(_ => wsPipe)
val wsServerEndpoint = wsEndpoint.handleSuccess(_ => wsPipe2)

// A regular /GET endpoint
val helloWorldEndpoint =
Expand Down
2 changes: 1 addition & 1 deletion project/Versions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ object Versions {
val json4s = "4.0.7"
val metrics4Scala = "4.3.2"
val nettyReactiveStreams = "3.0.2"
val ox = "0.3.1"
val ox = "0.3.7"
val reactiveStreams = "1.0.4"
val sprayJson = "1.3.6"
val scalaCheck = "1.18.0"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,53 +1,63 @@
package sttp.tapir.server.netty.sync.internal.ws

import io.netty.channel.ChannelHandlerContext
import io.netty.handler.codec.http.websocketx.{CloseWebSocketFrame, WebSocketCloseStatus, WebSocketFrame => NettyWebSocketFrame}
import io.netty.handler.codec.http.websocketx.{CloseWebSocketFrame, WebSocketCloseStatus, WebSocketFrame as NettyWebSocketFrame}
import org.reactivestreams.{Processor, Subscriber, Subscription}
import org.slf4j.LoggerFactory
import ox.*
import ox.channels.{ChannelClosedException, Source}
import sttp.tapir.model.WebSocketFrameDecodeFailure
import sttp.tapir.server.netty.internal.ws.WebSocketFrameConverters._
import sttp.tapir.server.netty.internal.ws.WebSocketFrameConverters.*
import sttp.tapir.server.netty.sync.OxStreams
import sttp.tapir.server.netty.sync.internal.ox.OxDispatcher
import sttp.tapir.server.netty.sync.internal.reactivestreams.OxProcessor
import sttp.tapir.{DecodeResult, WebSocketBodyOutput}
import sttp.ws.WebSocketFrame

import java.io.IOException
import java.util.concurrent.Semaphore

import scala.concurrent.duration.*

private[sync] object OxSourceWebSocketProcessor:
private val logger = LoggerFactory.getLogger(getClass.getName)
private val outgoingCloseAfterCloseTimeout = 1.second

def apply[REQ, RESP](
oxDispatcher: OxDispatcher,
pipe: OxStreams.Pipe[REQ, RESP],
processingPipe: OxStreams.Pipe[REQ, RESP],
o: WebSocketBodyOutput[OxStreams.Pipe[REQ, RESP], REQ, RESP, ?, OxStreams],
ctx: ChannelHandlerContext
): Processor[NettyWebSocketFrame, NettyWebSocketFrame] =
val frame2FramePipe: OxStreams.Pipe[NettyWebSocketFrame, NettyWebSocketFrame] =
(source: Source[NettyWebSocketFrame]) => {
pipe(
optionallyConcatenateFrames(o.concatenateFragmentedFrames)(
takeUntilCloseFrame(passAlongCloseFrame = o.decodeCloseRequests)(
source
.mapAsView { f =>
val sttpFrame = nettyFrameToFrame(f)
f.release()
sttpFrame
}
)
)
.mapAsView(f =>
o.requests.decode(f) match {
case failure: DecodeResult.Failure => throw new WebSocketFrameDecodeFailure(f, failure)
case x: DecodeResult.Value[REQ] @unchecked => x.v
}
)
)
def decodeFrame(f: WebSocketFrame): REQ = o.requests.decode(f) match {
case failure: DecodeResult.Failure => throw new WebSocketFrameDecodeFailure(f, failure)
case x: DecodeResult.Value[REQ] @unchecked => x.v
}

val frame2FramePipe: OxStreams.Pipe[NettyWebSocketFrame, NettyWebSocketFrame] = ox ?=>
val closeSignal = new Semaphore(0)
(incoming: Source[NettyWebSocketFrame]) =>
val outgoing = incoming
.mapAsView { f =>
val sttpFrame = nettyFrameToFrame(f)
f.release()
sttpFrame
}
.pipe(takeUntilCloseFrame(passAlongCloseFrame = o.decodeCloseRequests, closeSignal))
.pipe(optionallyConcatenateFrames(o.concatenateFragmentedFrames))
.mapAsView(decodeFrame)
.pipe(processingPipe)
.mapAsView(r => frameToNettyFrame(o.responses.encode(r)))
}

// when the client closes the connection, we need to close the outgoing channel as well - this needs to be
// done in the client's pipeline code; monitoring that this happens within a timeout after the close happens
monitorOutgoingClosedAfterClientClose(closeSignal, outgoing)

outgoing
end frame2FramePipe

// We need this kind of interceptor to make Netty reply correctly to closed channel or error
def wrapSubscriberWithNettyCallback[B](sub: Subscriber[? >: B]): Subscriber[? >: B] = new Subscriber[B] {
private val logger = LoggerFactory.getLogger(getClass.getName)
override def onSubscribe(s: Subscription): Unit = sub.onSubscribe(s)
override def onNext(t: B): Unit = sub.onNext(t)
override def onError(t: Throwable): Unit =
Expand All @@ -64,16 +74,33 @@ private[sync] object OxSourceWebSocketProcessor:
sub.onComplete()
}
new OxProcessor(oxDispatcher, frame2FramePipe, wrapSubscriberWithNettyCallback)
end apply

private def optionallyConcatenateFrames(doConcatenate: Boolean)(s: Source[WebSocketFrame])(using Ox): Source[WebSocketFrame] =
if doConcatenate then s.mapStateful(() => None: Accumulator)(accumulateFrameState).collectAsView { case Some(f: WebSocketFrame) => f }
else s

private def takeUntilCloseFrame(passAlongCloseFrame: Boolean)(s: Source[WebSocketFrame])(using Ox): Source[WebSocketFrame] =
private def takeUntilCloseFrame(passAlongCloseFrame: Boolean, closeSignal: Semaphore)(
s: Source[WebSocketFrame]
)(using Ox): Source[WebSocketFrame] =
s.takeWhile(
{
case _: WebSocketFrame.Close => false
case _: WebSocketFrame.Close => closeSignal.release(); false
case _ => true
},
includeFirstFailing = passAlongCloseFrame
)

private def monitorOutgoingClosedAfterClientClose(closeSignal: Semaphore, outgoing: Source[_])(using Ox): Unit =
// will be interrupted when outgoing is completed
fork {
closeSignal.acquire()
sleep(outgoingCloseAfterCloseTimeout)
if !outgoing.isClosedForReceive then
logger.error(
s"WebSocket outgoing messages channel either not drained, or not closed, " +
s"$outgoingCloseAfterCloseTimeout after receiving a close frame from the client! " +
s"Make sure to complete the outgoing channel in your pipeline, once the incoming " +
s"channel is done!"
)
}.discard
Loading