Skip to content

Commit

Permalink
Initial implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
kciesielski committed Mar 23, 2024
1 parent 00cada7 commit fc483b1
Show file tree
Hide file tree
Showing 24 changed files with 582 additions and 44 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
package sttp.tapir.examples

import cats.effect.{IO, IOApp}
import sttp.client3._
import sttp.model.StatusCode
import sttp.tapir.server.netty.cats.NettyCatsServer
import sttp.tapir.*
import scala.concurrent.duration._
import sttp.capabilities.fs2.Fs2Streams
import sttp.ws.WebSocket
import sttp.client3.pekkohttp.PekkoHttpBackend
import scala.concurrent.Future

object WebSocketsNettyCatsServer extends IOApp.Simple {
// One endpoint on GET /hello with query parameter `name`
val helloWorldEndpoint: PublicEndpoint[String, Unit, String, Any] =
endpoint.get.in("hello").in(query[String]("name")).out(stringBody)

val wsEndpoint =
endpoint.get.in("ws").out(webSocketBody[String, CodecFormat.TextPlain, String, CodecFormat.TextPlain](Fs2Streams[IO]))

val wsServerEndpoint = wsEndpoint.serverLogicSuccess(_ =>
IO.pure(in => in.evalMap(str => IO.println(s"responding with ${str.toUpperCase}") >> IO.pure(str.toUpperCase())))
)
// Just returning passed name with `Hello, ` prepended
val helloWorldServerEndpoint = helloWorldEndpoint
.serverLogic(name => IO.pure[Either[Unit, String]](Right(s"Hello, $name!")))

private val declaredPort = 9090
private val declaredHost = "localhost"

// Creating handler for netty bootstrap
override def run = NettyCatsServer
.io()
.use { server =>
for {
binding <- server
.port(declaredPort)
.host(declaredHost)
.addEndpoints(List(wsServerEndpoint, helloWorldServerEndpoint))
.start()
result <- IO
.fromFuture(IO.delay {
val port = binding.port
val host = binding.hostName
println(s"Server started at port = ${binding.port}")
import scala.concurrent.ExecutionContext.Implicits.global
def useWebSocket(ws: WebSocket[Future]): Future[Unit] = {
def send(i: Int) = ws.sendText(s"Hello $i!")
def receive() = ws.receiveText().map(t => println(s"Client RECEIVED: $t"))
for {
_ <- send(1)
_ <- receive()
_ <- send(2)
_ <- send(3)
_ <- receive()
} yield ()
}
val backend = PekkoHttpBackend()

val url = uri"ws://$host:$port/ws"
val allGood = uri"http://$host:$port/hello?name=Netty"
basicRequest.response(asStringAlways).get(allGood).send(backend).map(r => println(r.body))
.flatMap { _ =>
basicRequest
.response(asWebSocket(useWebSocket))
.get(url)
.send(backend)
}
.andThen { case _ => backend.close() }
})
.guarantee(binding.stop())
} yield result
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,36 @@ package sttp.tapir.perf.netty.cats
import cats.effect.IO
import cats.effect.kernel.Resource
import cats.effect.std.Dispatcher
import fs2.Stream
import sttp.tapir.{CodecFormat, webSocketBody}
import sttp.tapir.perf.Common._
import sttp.tapir.perf.apis._
import sttp.tapir.server.ServerEndpoint
import sttp.tapir.server.netty.cats.NettyCatsServer
import sttp.tapir.server.netty.cats.NettyCatsServerOptions
import sttp.ws.WebSocketFrame
import sttp.capabilities.fs2.Fs2Streams

object Tapir extends Endpoints
import scala.concurrent.duration._

object NettyCats {
object Tapir extends Endpoints {
val wsResponseStream = Stream.fixedRate[IO](WebSocketSingleResponseLag, dampen = false)
val wsEndpoint = wsBaseEndpoint
.out(
webSocketBody[Long, CodecFormat.TextPlain, Long, CodecFormat.TextPlain](Fs2Streams[IO])
.concatenateFragmentedFrames(false)
.autoPongOnPing(false)
.ignorePong(true)
.autoPing(None)
)
}

object NettyCats {
val wsServerEndpoint = Tapir.wsEndpoint.serverLogicSuccess(_ =>
IO.pure { (in: Stream[IO, Long]) =>
Tapir.wsResponseStream.evalMap(_ => IO.realTime.map(_.toMillis)).concurrently(in.as(()))
}
)
def runServer(endpoints: List[ServerEndpoint[Any, IO]], withServerLog: Boolean = false): IO[ServerRunner.KillSwitch] = {
val declaredPort = Port
val declaredHost = "0.0.0.0"
Expand All @@ -25,7 +45,7 @@ object NettyCats {
server
.port(declaredPort)
.host(declaredHost)
.addEndpoints(endpoints)
.addEndpoints(wsServerEndpoint :: endpoints)
.start()
)(binding => binding.stop())
} yield ()).allocated.map(_._2)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,14 @@ import java.util.UUID
import java.util.concurrent.atomic.AtomicBoolean
import scala.concurrent.Future
import scala.concurrent.duration._
import sttp.capabilities.WebSockets
import sttp.tapir.server.netty.internal.ReactiveWebSocketHandler

case class NettyCatsServer[F[_]: Async](routes: Vector[Route[F]], options: NettyCatsServerOptions[F], config: NettyConfig) {
def addEndpoint(se: ServerEndpoint[Fs2Streams[F], F]): NettyCatsServer[F] = addEndpoints(List(se))
def addEndpoint(se: ServerEndpoint[Fs2Streams[F], F], overrideOptions: NettyCatsServerOptions[F]): NettyCatsServer[F] =
addEndpoints(List(se), overrideOptions)
def addEndpoints(ses: List[ServerEndpoint[Fs2Streams[F], F]]): NettyCatsServer[F] = addRoute(
def addEndpoints(ses: List[ServerEndpoint[Fs2Streams[F] with WebSockets, F]]): NettyCatsServer[F] = addRoute(
NettyCatsServerInterpreter(options).toRoute(ses)
)
def addEndpoints(ses: List[ServerEndpoint[Fs2Streams[F], F]], overrideOptions: NettyCatsServerOptions[F]): NettyCatsServer[F] = addRoute(
Expand Down Expand Up @@ -75,6 +77,7 @@ case class NettyCatsServer[F[_]: Async](routes: Vector[Route[F]], options: Netty
NettyBootstrap(
config,
new NettyServerHandler(route, unsafeRunAsync, channelGroup, isShuttingDown, config.serverHeader),
new ReactiveWebSocketHandler(route, channelGroup, unsafeRunAsync, config.sslContext.isDefined),
eventLoopGroup,
socketOverride
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,13 @@ import sttp.tapir.server.interpreter.{BodyListener, FilterServerEndpoints, Serve
import sttp.tapir.server.netty.internal.{NettyBodyListener, RunAsync, _}
import sttp.tapir.server.netty.cats.internal.NettyCatsRequestBody
import sttp.tapir.server.netty.{NettyResponse, NettyServerRequest, Route}
import sttp.capabilities.WebSockets

trait NettyCatsServerInterpreter[F[_]] {
implicit def async: Async[F]
def nettyServerOptions: NettyCatsServerOptions[F]

def toRoute(ses: List[ServerEndpoint[Fs2Streams[F], F]]): Route[F] = {
def toRoute(ses: List[ServerEndpoint[Fs2Streams[F] with WebSockets, F]]): Route[F] = {

implicit val monad: MonadError[F] = new CatsMonadError[F]
val runAsync = new RunAsync[F] {
Expand All @@ -31,7 +32,7 @@ trait NettyCatsServerInterpreter[F[_]] {
val createFile = nettyServerOptions.createFile
val deleteFile = nettyServerOptions.deleteFile

val serverInterpreter = new ServerInterpreter[Fs2Streams[F], F, NettyResponse, Fs2Streams[F]](
val serverInterpreter = new ServerInterpreter[Fs2Streams[F] with WebSockets, F, NettyResponse, Fs2Streams[F]](
FilterServerEndpoints(ses),
new NettyCatsRequestBody(createFile, Fs2StreamCompatible[F](nettyServerOptions.dispatcher)),
new NettyToStreamsResponseBody(Fs2StreamCompatible[F](nettyServerOptions.dispatcher)),
Expand Down
Original file line number Diff line number Diff line change
@@ -1,22 +1,19 @@
package sttp.tapir.server.netty.cats.internal

import cats.effect.kernel.{Async, Sync}
import cats.effect.std.Dispatcher
import fs2.interop.reactivestreams.{StreamSubscriber, StreamUnicastPublisher}
import fs2.io.file.{Files, Flags, Path}
import fs2.{Chunk, Pipe}
import io.netty.buffer.Unpooled
import io.netty.handler.codec.http.websocketx.WebSocketFrame
import io.netty.handler.codec.http.{DefaultHttpContent, HttpContent}
import org.reactivestreams.Publisher
import sttp.tapir.FileRange
import org.reactivestreams.{Processor, Publisher}
import sttp.capabilities.fs2.Fs2Streams
import sttp.tapir.server.netty.internal._
import sttp.tapir.{FileRange, WebSocketBodyOutput}

import java.io.InputStream
import cats.effect.std.Dispatcher
import sttp.capabilities.fs2.Fs2Streams
import fs2.io.file.Path
import fs2.io.file.Files
import cats.effect.kernel.Async
import fs2.io.file.Flags
import fs2.interop.reactivestreams.StreamUnicastPublisher
import cats.effect.kernel.Sync
import fs2.Chunk
import fs2.interop.reactivestreams.StreamSubscriber

object Fs2StreamCompatible {

Expand Down Expand Up @@ -68,6 +65,12 @@ object Fs2StreamCompatible {
override def emptyStream: streams.BinaryStream =
fs2.Stream.empty

override def asWsProcessor[REQ, RESP](
pipe: Pipe[F, REQ, RESP],
o: WebSocketBodyOutput[Pipe[F, REQ, RESP], REQ, RESP, ?, Fs2Streams[F]]
): Processor[WebSocketFrame, WebSocketFrame] =
new WebSocketPipeProcessor[F, REQ, RESP](pipe, dispatcher, o)

private def inputStreamToFs2(inputStream: () => InputStream, chunkSize: Int) =
fs2.io.readInputStream(
Sync[F].blocking(inputStream()),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import sttp.tapir.TapirFile
import sttp.tapir.integ.cats.effect.CatsMonadError
import sttp.tapir.model.ServerRequest
import sttp.tapir.server.netty.internal.{NettyStreamingRequestBody, StreamCompatible}
import sttp.capabilities.WebSockets

private[cats] class NettyCatsRequestBody[F[_]: Async](
val createFile: ServerRequest => F[TapirFile],
Expand All @@ -24,7 +25,8 @@ private[cats] class NettyCatsRequestBody[F[_]: Async](
streamCompatible.fromPublisher(publisher, maxBytes).compile.to(Chunk).map(_.toArray[Byte])

override def writeToFile(serverRequest: ServerRequest, file: TapirFile, maxBytes: Option[Long]): F[Unit] =
(toStream(serverRequest, maxBytes).asInstanceOf[streamCompatible.streams.BinaryStream])
(toStream(serverRequest, maxBytes)
.asInstanceOf[streamCompatible.streams.BinaryStream])
.through(
Files[F](Files.forAsync[F]).writeAll(Path.fromNioPath(file.toPath))
)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
package sttp.tapir.server.netty.cats.internal

import cats.Applicative
import cats.effect.kernel.Async
import cats.effect.std.Dispatcher
import fs2.interop.reactivestreams.{StreamSubscriber, StreamUnicastPublisher}
import fs2.{Pipe, Stream}
import io.netty.handler.codec.http.websocketx.{WebSocketFrame => NettyWebSocketFrame}
import org.reactivestreams.{Processor, Publisher, Subscriber, Subscription}
import sttp.capabilities.fs2.Fs2Streams
import sttp.tapir.model.WebSocketFrameDecodeFailure
import sttp.tapir.server.netty.internal.WebSocketFrameConverters._
import sttp.tapir.{DecodeResult, WebSocketBodyOutput}
import sttp.ws.WebSocketFrame

import scala.concurrent.ExecutionContext.Implicits
import scala.concurrent.Promise
import scala.util.{Failure, Success}

class WebSocketPipeProcessor[F[_]: Async, REQ, RESP](
pipe: Pipe[F, REQ, RESP],
dispatcher: Dispatcher[F],
o: WebSocketBodyOutput[Pipe[F, REQ, RESP], REQ, RESP, ?, Fs2Streams[F]]
) extends Processor[NettyWebSocketFrame, NettyWebSocketFrame] {
private var subscriber: StreamSubscriber[F, NettyWebSocketFrame] = _
private val publisher: Promise[Publisher[NettyWebSocketFrame]] = Promise[Publisher[NettyWebSocketFrame]]()
private var subscription: Subscription = _

override def onSubscribe(s: Subscription): Unit = {
subscriber = dispatcher.unsafeRunSync(
StreamSubscriber[F, NettyWebSocketFrame](bufferSize = 1)
)
subscription = s
val in: Stream[F, NettyWebSocketFrame] = subscriber.sub.stream(Applicative[F].unit)
val sttpFrames = in.map { f =>
val sttpFrame = nettyFrameToFrame(f)
f.release()
sttpFrame
}
val stream: Stream[F, NettyWebSocketFrame] =
optionallyConcatenateFrames(sttpFrames, o.concatenateFragmentedFrames)
.map(f =>
o.requests.decode(f) match {
case x: DecodeResult.Value[REQ] => x.v
case failure: DecodeResult.Failure => throw new WebSocketFrameDecodeFailure(f, failure)
}
)
.through(pipe)
.map(r => frameToNettyFrame(o.responses.encode(r)))
.append(fs2.Stream(frameToNettyFrame(WebSocketFrame.close)))

subscriber.sub.onSubscribe(s)
publisher.success(StreamUnicastPublisher(stream, dispatcher))
}

override def onNext(t: NettyWebSocketFrame): Unit = {
subscriber.sub.onNext(t)
}

override def onError(t: Throwable): Unit = {
subscriber.sub.onError(t)
}

override def onComplete(): Unit = {
subscriber.sub.onComplete()
}

override def subscribe(s: Subscriber[_ >: NettyWebSocketFrame]): Unit = {
publisher.future.onComplete {
case Success(p) =>
p.subscribe(s)
case Failure(ex) =>
subscriber.sub.onError(ex)
subscription.cancel
}(Implicits.global)
}

private def optionallyConcatenateFrames(s: Stream[F, WebSocketFrame], doConcatenate: Boolean): Stream[F, WebSocketFrame] =
if (doConcatenate) {
type Accumulator = Option[Either[Array[Byte], String]]

s.mapAccumulate(None: Accumulator) {
case (None, f: WebSocketFrame.Ping) => (None, Some(f))
case (None, f: WebSocketFrame.Pong) => (None, Some(f))
case (None, f: WebSocketFrame.Close) => (None, Some(f))
case (None, f: WebSocketFrame.Data[_]) if f.finalFragment => (None, Some(f))
case (Some(Left(acc)), f: WebSocketFrame.Binary) if f.finalFragment => (None, Some(f.copy(payload = acc ++ f.payload)))
case (Some(Left(acc)), f: WebSocketFrame.Binary) if !f.finalFragment => (Some(Left(acc ++ f.payload)), None)
case (Some(Right(acc)), f: WebSocketFrame.Text) if f.finalFragment => (None, Some(f.copy(payload = acc + f.payload)))
case (Some(Right(acc)), f: WebSocketFrame.Text) if !f.finalFragment => (Some(Right(acc + f.payload)), None)
case (acc, f) => throw new IllegalStateException(s"Cannot accumulate web socket frames. Accumulator: $acc, frame: $f.")
}.collect { case (_, Some(f)) => f }
} else s
}
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,11 @@ class NettyCatsServerTest extends TestSuite with EitherValues {
new ServerStreamingTests(createServerTest).tests(Fs2Streams[IO])(drainFs2) ++
new ServerCancellationTests(createServerTest)(m, IO.asyncForIO).tests() ++
new NettyFs2StreamingCancellationTest(createServerTest).tests() ++
new ServerGracefulShutdownTests(createServerTest, ioSleeper).tests()
new ServerGracefulShutdownTests(createServerTest, ioSleeper).tests() ++
new ServerWebSocketTests(createServerTest, Fs2Streams[IO]) {
override def functionToPipe[A, B](f: A => B): streams.Pipe[A, B] = in => in.map(f)
override def emptyPipe[A, B]: fs2.Pipe[IO, A, B] = _ => fs2.Stream.empty
}.tests()

IO.pure((tests, eventLoopGroup))
} { case (_, eventLoopGroup) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import cats.data.NonEmptyList
import cats.effect.std.Dispatcher
import cats.effect.{IO, Resource}
import io.netty.channel.nio.NioEventLoopGroup
import sttp.capabilities.WebSockets
import sttp.tapir.server.ServerEndpoint
import sttp.tapir.server.netty.{NettyConfig, Route}
import sttp.tapir.server.tests.TestServerInterpreter
Expand All @@ -12,8 +13,8 @@ import sttp.capabilities.fs2.Fs2Streams
import scala.concurrent.duration.FiniteDuration

class NettyCatsTestServerInterpreter(eventLoopGroup: NioEventLoopGroup, dispatcher: Dispatcher[IO])
extends TestServerInterpreter[IO, Fs2Streams[IO], NettyCatsServerOptions[IO], Route[IO]] {
override def route(es: List[ServerEndpoint[Fs2Streams[IO], IO]], interceptors: Interceptors): Route[IO] = {
extends TestServerInterpreter[IO, Fs2Streams[IO] with WebSockets, NettyCatsServerOptions[IO], Route[IO]] {
override def route(es: List[ServerEndpoint[Fs2Streams[IO] with WebSockets, IO]], interceptors: Interceptors): Route[IO] = {
val serverOptions: NettyCatsServerOptions[IO] = interceptors(
NettyCatsServerOptions.customiseInterceptors[IO](dispatcher)
).options
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import scala.concurrent.Future
import scala.concurrent.Promise
import scala.concurrent.duration.FiniteDuration
import scala.util.control.NonFatal
import sttp.tapir.server.netty.internal.ReactiveWebSocketHandler

case class NettyIdServer(routes: Vector[IdRoute], options: NettyIdServerOptions, config: NettyConfig) {
private val executor = Executors.newVirtualThreadPerTaskExecutor()
Expand Down Expand Up @@ -100,6 +101,7 @@ case class NettyIdServer(routes: Vector[IdRoute], options: NettyIdServerOptions,
isShuttingDown,
config.serverHeader
),
new ReactiveWebSocketHandler(route, channelGroup, unsafeRunF, config.sslContext.isDefined),
eventLoopGroup,
socketOverride
)
Expand Down
Loading

0 comments on commit fc483b1

Please sign in to comment.