Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Websocket support for vertx servers #2770

Merged
merged 1 commit into from
Mar 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import cats.syntax.all._
import io.vertx.core.logging.LoggerFactory
import io.vertx.core.{Future, Handler}
import io.vertx.ext.web.{Route, Router, RoutingContext}
import sttp.capabilities.Streams
import sttp.capabilities.{Streams, WebSockets}
import sttp.capabilities.fs2.Fs2Streams
import sttp.monad.MonadError
import sttp.tapir.server.ServerEndpoint
Expand Down Expand Up @@ -36,21 +36,21 @@ trait VertxCatsServerInterpreter[F[_]] extends CommonServerInterpreter {
* A function, that given a router, will attach this endpoint to it
*/
def route(
e: ServerEndpoint[Fs2Streams[F], F]
e: ServerEndpoint[Fs2Streams[F] with WebSockets, F]
): Router => Route = { router =>
val readStreamCompatible = fs2ReadStreamCompatible(vertxCatsServerOptions)
mountWithDefaultHandlers(e)(router, extractRouteDefinition(e.endpoint)).handler(endpointHandler(e, readStreamCompatible))
}

private def endpointHandler[S <: Streams[S]](
e: ServerEndpoint[Fs2Streams[F], F],
e: ServerEndpoint[Fs2Streams[F] with WebSockets, F],
readStreamCompatible: ReadStreamCompatible[S]
): Handler[RoutingContext] = {
implicit val monad: MonadError[F] = monadError[F]
implicit val bodyListener: BodyListener[F, RoutingContext => Future[Void]] =
new VertxBodyListener[F](new CatsRunAsync(vertxCatsServerOptions.dispatcher))
val fFromVFuture = new CatsFFromVFuture[F]
val interpreter: ServerInterpreter[Fs2Streams[F], F, RoutingContext => Future[Void], S] = new ServerInterpreter(
val interpreter: ServerInterpreter[Fs2Streams[F] with WebSockets, F, RoutingContext => Future[Void], S] = new ServerInterpreter(
_ => List(e),
new VertxRequestBody(vertxCatsServerOptions, fFromVFuture)(readStreamCompatible),
new VertxToResponseBody(vertxCatsServerOptions)(readStreamCompatible),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,13 @@ import io.vertx.core.Handler
import io.vertx.core.buffer.Buffer
import io.vertx.core.streams.ReadStream
import sttp.capabilities.fs2.Fs2Streams
import sttp.tapir.{DecodeResult, WebSocketBodyOutput}
import sttp.tapir.model.WebSocketFrameDecodeFailure
import sttp.tapir.server.vertx.cats.VertxCatsServerOptions
import sttp.tapir.server.vertx.streams.ReadStreamState.{WrappedBuffer, WrappedEvent}
import sttp.tapir.server.vertx.streams.websocket._
import sttp.tapir.server.vertx.streams._
import sttp.ws.WebSocketFrame

import scala.collection.immutable.{Queue => SQueue}

Expand All @@ -29,15 +33,18 @@ object fs2 {
new ReadStreamCompatible[Fs2Streams[F]] {
override val streams: Fs2Streams[F] = Fs2Streams[F]

override def asReadStream(stream: Stream[F, Byte]): ReadStream[Buffer] = {
override def asReadStream(stream: Stream[F, Byte]): ReadStream[Buffer] =
mapToReadStream[Chunk[Byte], Buffer](stream.chunks, chunk => Buffer.buffer(chunk.toArray))

private def mapToReadStream[I, O](stream: Stream[F, I], fn: I => O): ReadStream[O] =
opts.dispatcher.unsafeRunSync {
for {
promise <- Deferred[F, Unit]
state <- Ref.of(StreamState.empty[F](promise))
state <- Ref.of(StreamState.empty[F, O](promise))
_ <- F.start(
stream.chunks
stream
.evalMap({ chunk =>
val buffer = Buffer.buffer(chunk.toArray)
val buffer = fn(chunk)
state.get.flatMap {
case StreamState(None, handler, _, _) =>
F.delay(handler.handle(buffer))
Expand Down Expand Up @@ -67,18 +74,18 @@ object fs2 {
.compile
.drain
)
} yield new ReadStream[Buffer] {
} yield new ReadStream[O] {
self =>
override def handler(handler: Handler[Buffer]): ReadStream[Buffer] =
override def handler(handler: Handler[O]): ReadStream[O] =
opts.dispatcher.unsafeRunSync(state.update(_.copy(handler = handler)).as(self))

override def endHandler(handler: Handler[Void]): ReadStream[Buffer] =
override def endHandler(handler: Handler[Void]): ReadStream[O] =
opts.dispatcher.unsafeRunSync(state.update(_.copy(endHandler = handler)).as(self))

override def exceptionHandler(handler: Handler[Throwable]): ReadStream[Buffer] =
override def exceptionHandler(handler: Handler[Throwable]): ReadStream[O] =
opts.dispatcher.unsafeRunSync(state.update(_.copy(errorHandler = handler)).as(self))

override def pause(): ReadStream[Buffer] =
override def pause(): ReadStream[O] =
opts.dispatcher.unsafeRunSync(for {
deferred <- Deferred[F, Unit]
_ <- state.update {
Expand All @@ -89,25 +96,27 @@ object fs2 {
}
} yield self)

override def resume(): ReadStream[Buffer] =
override def resume(): ReadStream[O] =
opts.dispatcher.unsafeRunSync(for {
oldState <- state.getAndUpdate(_.copy(paused = None))
_ <- oldState.paused.fold(Async[F].unit)(_.complete(()))
} yield self)

override def fetch(n: Long): ReadStream[Buffer] =
override def fetch(n: Long): ReadStream[O] =
self
}
}
}

override def fromReadStream(readStream: ReadStream[Buffer]): Stream[F, Byte] =
fromReadStreamInternal(readStream).map(buffer => Chunk.array(buffer.getBytes)).unchunks

private def fromReadStreamInternal[T](readStream: ReadStream[T]): Stream[F, T] =
opts.dispatcher.unsafeRunSync {
for {
stateRef <- Ref.of(ReadStreamState[F, Chunk[Byte]](Queued(SQueue.empty), Queued(SQueue.empty)))
stream = Stream.unfoldChunkEval[F, Unit, Byte](()) { _ =>
stateRef <- Ref.of(ReadStreamState[F, T](Queued(SQueue.empty), Queued(SQueue.empty)))
stream = Stream.unfoldEval[F, Unit, T](()) { _ =>
for {
dfd <- Deferred[F, WrappedBuffer[Chunk[Byte]]]
dfd <- Deferred[F, WrappedBuffer[T]]
tuple <- stateRef.modify(_.dequeueBuffer(dfd))
(mbBuffer, mbAction) = tuple
_ <- mbAction.traverse(identity)
Expand Down Expand Up @@ -154,14 +163,55 @@ object fs2 {
opts.dispatcher.unsafeRunSync(stateRef.modify(_.halt(Some(cause))).flatMap(_.sequence_))
}
readStream.handler { buffer =>
val chunk = Chunk.array(buffer.getBytes)
val maxSize = opts.maxQueueSizeForReadStream
opts.dispatcher.unsafeRunSync(stateRef.modify(_.enqueue(chunk, maxSize)).flatMap(_.sequence_))
opts.dispatcher.unsafeRunSync(stateRef.modify(_.enqueue(buffer, maxSize)).flatMap(_.sequence_))
}

stream
}
}

override def webSocketPipe[REQ, RESP](
readStream: ReadStream[WebSocketFrame],
pipe: streams.Pipe[REQ, RESP],
o: WebSocketBodyOutput[streams.Pipe[REQ, RESP], REQ, RESP, _, Fs2Streams[F]]
): ReadStream[WebSocketFrame] = {
val stream0 = fromReadStreamInternal(readStream)
val stream1 = optionallyContatenateFrames(stream0, o.concatenateFragmentedFrames)
val stream2 = optionallyIgnorePong(stream1, o.ignorePong)
val autoPings = o.autoPing match {
case Some((interval, frame)) =>
Stream.awakeEvery(interval).as(frame)
case None =>
Stream.empty
}

val stream3 = stream2
.map { frame =>
o.requests.decode(frame) match {
case DecodeResult.Value(v) =>
v
case failure: DecodeResult.Failure =>
throw new WebSocketFrameDecodeFailure(frame, failure)
}
}
.through(pipe)
.map(o.responses.encode)
.mergeHaltL(autoPings)
.append(Stream(WebSocketFrame.close))

mapToReadStream[WebSocketFrame, WebSocketFrame](stream3, identity)
}

def optionallyContatenateFrames(s: Stream[F, WebSocketFrame], doConcatenate: Boolean): Stream[F, WebSocketFrame] =
if (doConcatenate) {
s.mapAccumulate(None: Accumulator)(concatenateFrames).collect { case (_, Some(f)) => f }
} else {
s
}

def optionallyIgnorePong(s: Stream[F, WebSocketFrame], ignore: Boolean): Stream[F, WebSocketFrame] =
if (ignore) s.filterNot(isPong) else s
}
}
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package sttp.tapir.server.vertx.cats

import cats.effect.{IO, Resource}
import fs2.Stream
import io.vertx.core.Vertx
import sttp.capabilities.fs2.Fs2Streams
import sttp.monad.MonadError
Expand All @@ -25,7 +26,11 @@ class CatsVertxServerTest extends TestSuite {
partContentTypeHeaderSupport = false, // README: doesn't seem supported but I may be wrong
partOtherHeaderSupport = false
).tests() ++
new ServerStreamingTests(createServerTest, Fs2Streams.apply[IO]).tests()
new ServerStreamingTests(createServerTest, Fs2Streams.apply[IO]).tests() ++
new ServerWebSocketTests(createServerTest, Fs2Streams.apply[IO]) {
override def functionToPipe[A, B](f: A => B): streams.Pipe[A, B] = in => in.map(f)
override def emptyPipe[A, B]: streams.Pipe[A, B] = _ => Stream.empty
}.tests()
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,18 @@ import io.vertx.core.Vertx
import io.vertx.core.http.HttpServerOptions
import io.vertx.ext.web.{Route, Router}
import sttp.capabilities.fs2.Fs2Streams
import sttp.capabilities.WebSockets
import sttp.tapir.server.ServerEndpoint
import sttp.tapir.server.tests.TestServerInterpreter
import sttp.tapir.server.vertx.cats.VertxCatsServerInterpreter.CatsFFromVFuture
import sttp.tapir.tests.Port

class CatsVertxTestServerInterpreter(vertx: Vertx, dispatcher: Dispatcher[IO])
extends TestServerInterpreter[IO, Fs2Streams[IO], VertxCatsServerOptions[IO], Router => Route] {
extends TestServerInterpreter[IO, Fs2Streams[IO] with WebSockets, VertxCatsServerOptions[IO], Router => Route] {

private val ioFromVFuture = new CatsFFromVFuture[IO]

override def route(es: List[ServerEndpoint[Fs2Streams[IO], IO]], interceptors: Interceptors): Router => Route = { router =>
override def route(es: List[ServerEndpoint[Fs2Streams[IO] with WebSockets, IO]], interceptors: Interceptors): Router => Route = { router =>
val options: VertxCatsServerOptions[IO] = interceptors(VertxCatsServerOptions.customiseInterceptors[IO](dispatcher)).options
val interpreter = VertxCatsServerInterpreter(options)
es.map(interpreter.route(_)(router)).last
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,16 @@ class VertxBodyListener[F[_]](runAsync: RunAsync[F])(implicit m: MonadError[F])
override def onComplete(body: RoutingContext => Future[Void])(cb: Try[Unit] => F[Unit]): F[RoutingContext => Future[Void]] = {
m.unit {
{ (ctx: RoutingContext) =>
body {
ctx.addBodyEndHandler(_ => runAsync(cb(Success(()))))
ctx.addEndHandler(res => if (res.failed()) runAsync(cb(Failure(res.cause()))))
ctx
// Unfortunately I can not find more reliable way to define that this request is actually websocket.
// When this code is called server response is not yet written.
if (ctx.request().getHeader("Upgrade") == "websocket") {
Future.succeededFuture(runAsync(cb(Success(())))).flatMap(_ => body(ctx))
} else {
body {
ctx.addBodyEndHandler(_ => runAsync(cb(Success(()))))
ctx.addEndHandler(res => if (res.failed()) runAsync(cb(Failure(res.cause()))))
ctx
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ class VertxRequestBody[F[_], S <: Streams[S]](

override def toStream(serverRequest: ServerRequest): streams.BinaryStream =
readStreamCompatible
.fromReadStream(routingContext(serverRequest).request.asInstanceOf[ReadStream[Buffer]])
.fromReadStream(routingContext(serverRequest).request)
.asInstanceOf[streams.BinaryStream]

private def extractStringPart[B](part: String, bodyType: RawBodyType[B]): Option[Any] = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package sttp.tapir.server.vertx.encoders

import io.vertx.core.Future
import io.vertx.core.buffer.Buffer
import io.vertx.core.http.{HttpHeaders, HttpServerResponse}
import io.vertx.core.http.{HttpHeaders, HttpServerResponse, ServerWebSocket}
import io.vertx.ext.web.RoutingContext
import sttp.capabilities.Streams
import sttp.model.{HasHeaders, Part}
Expand Down Expand Up @@ -52,7 +52,19 @@ class VertxToResponseBody[F[_], S <: Streams[S]](serverOptions: VertxServerOptio
override def fromWebSocketPipe[REQ, RESP](
pipe: streams.Pipe[REQ, RESP],
o: WebSocketBodyOutput[streams.Pipe[REQ, RESP], REQ, RESP, _, S]
): RoutingContext => Future[Void] = throw new UnsupportedOperationException()
): RoutingContext => Future[Void] = { rc =>
rc.request
.toWebSocket
.flatMap({ (websocket: ServerWebSocket) =>
Pipe(readStreamCompatible.webSocketPipe[REQ, RESP](
wrapWebSocket(websocket),
pipe.asInstanceOf[readStreamCompatible.streams.Pipe[REQ, RESP]],
o.asInstanceOf[WebSocketBodyOutput[readStreamCompatible.streams.Pipe[REQ, RESP], REQ, RESP, _, S]]
), websocket)
websocket.accept()
Future.succeededFuture[Void]()
})
}

private def handleMultipleBodyParts[CF <: CodecFormat, R](
multipart: RawBodyType[R] with RawBodyType.MultipartBody,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,13 @@ package sttp.tapir.server.vertx

import java.io.{ByteArrayInputStream, InputStream}

import io.vertx.core.buffer.Buffer
import io.vertx.core.Future
import io.vertx.core.Handler
import io.vertx.core.Vertx
import io.vertx.core.buffer.Buffer
import io.vertx.core.http.{ServerWebSocket, WebSocketFrameType}
import io.vertx.core.streams.ReadStream
import sttp.ws.WebSocketFrame

package object encoders {

Expand Down Expand Up @@ -32,4 +36,46 @@ package object encoders {
buffer
}

def wrapWebSocket(websocket: ServerWebSocket): ReadStream[WebSocketFrame] =
new ReadStream[WebSocketFrame] {
override def exceptionHandler(handler: Handler[Throwable]): ReadStream[WebSocketFrame] = {
websocket.exceptionHandler(handler)
this
}

override def handler(handler: Handler[WebSocketFrame]): ReadStream[WebSocketFrame] = {
websocket.frameHandler { frame =>
val t = frame.`type`()
if (t == WebSocketFrameType.TEXT) {
handler.handle(WebSocketFrame.Text(frame.textData(), frame.isFinal(), None))
} else if (t == WebSocketFrameType.BINARY) {
handler.handle(WebSocketFrame.Binary(frame.binaryData().getBytes(), frame.isFinal(), None))
}
}
websocket.pongHandler { buffer =>
handler.handle(WebSocketFrame.Pong(buffer.getBytes()))
}
this
}

override def pause(): ReadStream[WebSocketFrame] = {
websocket.pause()
this
}

override def resume(): ReadStream[WebSocketFrame] = {
websocket.resume()
this
}

override def fetch(amount: Long): ReadStream[WebSocketFrame] = {
websocket.fetch(amount)
this
}

override def endHandler(endHandler: Handler[Void]): ReadStream[WebSocketFrame] = {
websocket.endHandler(endHandler)
this
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package sttp.tapir.server.vertx
import io.vertx.core.Handler
import io.vertx.ext.web.{Route, RoutingContext}
import io.vertx.ext.web.handler.BodyHandler
import sttp.tapir.{Endpoint, EndpointIO}
import sttp.tapir.{Endpoint, EndpointIO, EndpointOutput}
import sttp.tapir.RawBodyType.MultipartBody
import sttp.tapir.internal._

Expand All @@ -22,16 +22,21 @@ package object handlers {
}

private[vertx] def attachDefaultHandlers[E](e: Endpoint[_, _, E, _, _], route: Route): Route = {
val mbWebsocketType = e.output.traverseOutputs[EndpointOutput.WebSocketBodyWrapper[_, _]] {
case body: EndpointOutput.WebSocketBodyWrapper[_, _] => Vector(body)
}

val bodyType = e.asVectorOfBasicInputs().flatMap {
case body: EndpointIO.Body[_, _] => Vector(body.bodyType)
case EndpointIO.OneOfBody(variants, _) => variants.flatMap(_.body.fold(body => Some(body.bodyType), _.bodyType)).toVector
case body: EndpointIO.StreamBodyWrapper[_, _] => Vector(body)
case _ => Vector.empty
}

bodyType.headOption match {
mbWebsocketType.headOption.orElse(bodyType.headOption) match {
case Some(MultipartBody(_, _)) => route.handler(multipartHandler)
case Some(_: EndpointIO.StreamBodyWrapper[_, _]) => route.handler(streamPauseHandler)
case Some(_: EndpointOutput.WebSocketBodyWrapper[_, _]) => route.handler(streamPauseHandler)
case Some(_) => route.handler(bodyHandler)
case None => ()
}
Expand Down
Loading