From 29a7ce5880b46be67e5cb6bf09e8b10fc0a5d40c Mon Sep 17 00:00:00 2001 From: Yehia AboSedira Date: Thu, 31 Aug 2023 19:23:03 +0200 Subject: [PATCH 1/4] Adding tests --- .../server/ziohttp/ZioHttpInterpreter.scala | 42 +++---- .../server/ziohttp/ZioHttpServerTest.scala | 115 ++++++++---------- .../ZioHttpTestServerInterpreter.scala | 5 +- 3 files changed, 72 insertions(+), 90 deletions(-) diff --git a/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioHttpInterpreter.scala b/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioHttpInterpreter.scala index 0c9c4b14ad..6137cf100a 100644 --- a/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioHttpInterpreter.scala +++ b/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioHttpInterpreter.scala @@ -1,5 +1,6 @@ package sttp.tapir.server.ziohttp +import sttp.capabilities.WebSockets import sttp.capabilities.zio.ZioStreams import sttp.model.{Method, Header => SttpHeader} import sttp.monad.MonadError @@ -13,21 +14,21 @@ import zio.http.{Header => ZioHttpHeader, Headers => ZioHttpHeaders, _} trait ZioHttpInterpreter[R] { def zioHttpServerOptions: ZioHttpServerOptions[R] = ZioHttpServerOptions.default - def toHttp[R2](se: ZServerEndpoint[R2, ZioStreams]): HttpApp[R & R2, Throwable] = + def toHttp[R2](se: ZServerEndpoint[R2, ZioStreams with WebSockets]): HttpApp[R & R2, Throwable] = toHttp(List(se)) - def toHttp[R2](ses: List[ZServerEndpoint[R2, ZioStreams]]): HttpApp[R & R2, Throwable] = { + def toHttp[R2](ses: List[ZServerEndpoint[R2, ZioStreams with WebSockets]]): HttpApp[R & R2, Throwable] = { implicit val bodyListener: ZioHttpBodyListener[R & R2] = new ZioHttpBodyListener[R & R2] - implicit val monadError: MonadError[RIO[R & R2, *]] = new RIOMonadError[R & R2] - val widenedSes = ses.map(_.widen[R & R2]) - val widenedServerOptions = zioHttpServerOptions.widen[R & R2] - val zioHttpRequestBody = new ZioHttpRequestBody(widenedServerOptions) - val zioHttpResponseBody = new ZioHttpToResponseBody - val interceptors = RejectInterceptor.disableWhenSingleEndpoint(widenedServerOptions.interceptors, widenedSes) + implicit val monadError: MonadError[RIO[R & R2, *]] = new RIOMonadError[R & R2] + val widenedSes = ses.map(_.widen[R & R2]) + val widenedServerOptions = zioHttpServerOptions.widen[R & R2] + val zioHttpRequestBody = new ZioHttpRequestBody(widenedServerOptions) + val zioHttpResponseBody = new ZioHttpToResponseBody + val interceptors = RejectInterceptor.disableWhenSingleEndpoint(widenedServerOptions.interceptors, widenedSes) - def handleRequest(req: Request, filteredEndpoints: List[ZServerEndpoint[R & R2, ZioStreams]]) = { + def handleRequest(req: Request, filteredEndpoints: List[ZServerEndpoint[R & R2, ZioStreams with WebSockets]]) = Handler.fromZIO { - val interpreter = new ServerInterpreter[ZioStreams, RIO[R & R2, *], ZioHttpResponseBody, ZioStreams]( + val interpreter = new ServerInterpreter[ZioStreams with WebSockets, RIO[R & R2, *], ZioHttpResponseBody, ZioStreams]( _ => filteredEndpoints, zioHttpRequestBody, zioHttpResponseBody, @@ -42,12 +43,12 @@ trait ZioHttpInterpreter[R] { { case RequestResult.Response(resp) => val baseHeaders = resp.headers.groupBy(_.name).flatMap(sttpToZioHttpHeader).toList - val allHeaders = resp.body.flatMap(_.contentLength) match { + val allHeaders = resp.body.flatMap(_.contentLength) match { case Some(contentLength) if resp.contentLength.isEmpty => ZioHttpHeader.ContentLength(contentLength) :: baseHeaders - case _ => baseHeaders + case _ => baseHeaders } - val statusCode = resp.code.code + val statusCode = resp.code.code ZIO.succeed( Response( @@ -61,7 +62,7 @@ trait ZioHttpInterpreter[R] { .getOrElse(Body.empty) ) ) - case RequestResult.Failure(_) => + case RequestResult.Failure(_) => ZIO.fail( new RuntimeException( s"The path: ${req.path} matches the shape of some endpoint, but none of the " + @@ -73,14 +74,13 @@ trait ZioHttpInterpreter[R] { } ) } - } - val serverEndpointsFilter = FilterServerEndpoints[ZioStreams, RIO[R & R2, *]](widenedSes) - val singleEndpoint = widenedSes.size == 1 + val serverEndpointsFilter = FilterServerEndpoints[ZioStreams with WebSockets, RIO[R & R2, *]](widenedSes) + val singleEndpoint = widenedSes.size == 1 Http.fromOptionalHandlerZIO { request => // pre-filtering the endpoints by shape to determine, if this request should be handled by tapir - val filteredEndpoints = serverEndpointsFilter.apply(ZioHttpServerRequest(request)) + val filteredEndpoints = serverEndpointsFilter.apply(ZioHttpServerRequest(request)) val filteredEndpoints2 = if (singleEndpoint) { // If we are interpreting a single endpoint, we verify that the method matches as well; in case it doesn't, // we refuse to handle the request, allowing other ZIO Http routes to handle it. Otherwise even if the method @@ -103,14 +103,12 @@ trait ZioHttpInterpreter[R] { } object ZioHttpInterpreter { - def apply[R](serverOptions: ZioHttpServerOptions[R]): ZioHttpInterpreter[R] = { + def apply[R](serverOptions: ZioHttpServerOptions[R]): ZioHttpInterpreter[R] = new ZioHttpInterpreter[R] { override def zioHttpServerOptions: ZioHttpServerOptions[R] = serverOptions } - } - def apply(): ZioHttpInterpreter[Any] = { + def apply(): ZioHttpInterpreter[Any] = new ZioHttpInterpreter[Any] { override def zioHttpServerOptions: ZioHttpServerOptions[Any] = ZioHttpServerOptions.default[Any] } - } } diff --git a/server/zio-http-server/src/test/scala/sttp/tapir/server/ziohttp/ZioHttpServerTest.scala b/server/zio-http-server/src/test/scala/sttp/tapir/server/ziohttp/ZioHttpServerTest.scala index 881eb87f06..9327b7d86d 100644 --- a/server/zio-http-server/src/test/scala/sttp/tapir/server/ziohttp/ZioHttpServerTest.scala +++ b/server/zio-http-server/src/test/scala/sttp/tapir/server/ziohttp/ZioHttpServerTest.scala @@ -1,46 +1,24 @@ package sttp.tapir.server.ziohttp -import cats.effect.IO -import cats.effect.Resource -import io.netty.channel.ChannelFactory -import io.netty.channel.EventLoopGroup -import io.netty.channel.ServerChannel -import org.scalatest.Assertion -import org.scalatest.Exceptional -import org.scalatest.FutureOutcome +import cats.effect.{IO, Resource} +import io.netty.channel.{ChannelFactory, EventLoopGroup, ServerChannel} +import org.scalatest.{Assertion, Exceptional, FutureOutcome} import org.scalatest.matchers.should.Matchers._ import sttp.capabilities.zio.ZioStreams import sttp.client3._ import sttp.client3.testing.SttpBackendStub import sttp.model.MediaType import sttp.monad.MonadError -import sttp.tapir.PublicEndpoint -import sttp.tapir._ +import sttp.tapir.{PublicEndpoint, _} import sttp.tapir.server.stub.TapirStubInterpreter import sttp.tapir.server.tests._ -import sttp.tapir.tests.Test -import sttp.tapir.tests.TestSuite -import sttp.tapir.ztapir.RIOMonadError -import sttp.tapir.ztapir.RichZEndpoint -import zio.Promise -import zio.Ref -import zio.Runtime -import zio.Task -import zio.UIO -import zio.Unsafe -import zio.ZEnvironment -import zio.ZIO -import zio.ZLayer -import zio.http.HttpAppMiddleware -import zio.http.Path -import zio.http.Request -import zio.http.URL -import zio.http.netty.ChannelFactories -import zio.http.netty.ChannelType -import zio.http.netty.EventLoopGroups +import sttp.tapir.tests.{Test, TestSuite} +import sttp.tapir.ztapir.{RIOMonadError, RichZEndpoint} +import zio.{Promise, Ref, Runtime, Task, UIO, Unsafe, ZEnvironment, ZIO, ZLayer} +import zio.http.{HttpAppMiddleware, Path, Request, URL} +import zio.http.netty.{ChannelFactories, ChannelType, EventLoopGroups} import zio.interop.catz._ -import zio.stream.ZPipeline -import zio.stream.ZStream +import zio.stream.{ZPipeline, ZStream} import java.nio.charset.Charset import java.time @@ -61,7 +39,7 @@ class ZioHttpServerTest extends TestSuite { println(s"Test ${test.name} failed, retrying.") e.printStackTrace() (if (count == 1) super.withFixture(test) else withFixture(test, count - 1)).toFuture - case other => Future.successful(other) + case other => Future.successful(other) }) } @@ -72,23 +50,23 @@ class ZioHttpServerTest extends TestSuite { .scoped[IO, Any, ZEnvironment[EventLoopGroup with ChannelFactory[ServerChannel]]]({ val eventConfig = ZLayer.succeed(new EventLoopGroups.Config { def channelType = ChannelType.AUTO - val nThreads = 0 + val nThreads = 0 }) val channelConfig: ZLayer[Any, Nothing, ChannelType.Config] = eventConfig (channelConfig >>> ChannelFactories.Server.fromConfig) ++ (eventConfig >>> EventLoopGroups.live) }.build) .map { nettyDeps => - val eventLoopGroup = ZLayer.succeed(nettyDeps.get[EventLoopGroup]) - val channelFactory = ZLayer.succeed(nettyDeps.get[ChannelFactory[ServerChannel]]) - val interpreter = new ZioHttpTestServerInterpreter(eventLoopGroup, channelFactory) + val eventLoopGroup = ZLayer.succeed(nettyDeps.get[EventLoopGroup]) + val channelFactory = ZLayer.succeed(nettyDeps.get[ChannelFactory[ServerChannel]]) + val interpreter = new ZioHttpTestServerInterpreter(eventLoopGroup, channelFactory) val createServerTest = new DefaultCreateServerTest(backend, interpreter) def additionalTests(): List[Test] = List( // https://github.com/softwaremill/tapir/issues/1914 Test("zio http route can be called with runZIO") { - val ep = endpoint.get.in("p1").out(stringBody).zServerLogic[Any](_ => ZIO.succeed("response")) - val route = ZioHttpInterpreter().toHttp(ep) + val ep = endpoint.get.in("p1").out(stringBody).zServerLogic[Any](_ => ZIO.succeed("response")) + val route = ZioHttpInterpreter().toHttp(ep) val test: UIO[Assertion] = route .runZIO(Request.get(url = URL.apply(Path.empty / "p1"))) .flatMap(response => response.body.asString) @@ -98,36 +76,36 @@ class ZioHttpServerTest extends TestSuite { }, Test("zio http middlewares run before the handler") { val test: UIO[Assertion] = for { - p <- Promise.make[Nothing, Unit] - ep = endpoint.get - .in("p1") - .out(stringBody) - .zServerLogic[Any](_ => p.await.timeout(time.Duration.ofSeconds(1)) *> ZIO.succeed("Ok")) - int = ZioHttpInterpreter().toHttp(ep) - route = int @@ HttpAppMiddleware.allowZIO((_: Request) => p.succeed(()).as(true)) + p <- Promise.make[Nothing, Unit] + ep = endpoint.get + .in("p1") + .out(stringBody) + .zServerLogic[Any](_ => p.await.timeout(time.Duration.ofSeconds(1)) *> ZIO.succeed("Ok")) + int = ZioHttpInterpreter().toHttp(ep) + route = int @@ HttpAppMiddleware.allowZIO((_: Request) => p.succeed(()).as(true)) result <- route - .runZIO(Request.get(url = URL(Path.empty / "p1"))) - .flatMap(response => response.body.asString) - .map(_ shouldBe "Ok") - .catchAll(_ => ZIO.succeed(fail("Unable to extract body from Http response"))) + .runZIO(Request.get(url = URL(Path.empty / "p1"))) + .flatMap(response => response.body.asString) + .map(_ shouldBe "Ok") + .catchAll(_ => ZIO.succeed(fail("Unable to extract body from Http response"))) } yield result Unsafe.unsafe(implicit u => r.unsafe.runToFuture(test)) }, Test("zio http middlewares only run once") { val test: UIO[Assertion] = for { - ref <- Ref.make(0) - ep = endpoint.get - .in("p1") - .out(stringBody) - .zServerLogic[Any](_ => ref.updateAndGet(_ + 1).map(_.toString)) - route = ZioHttpInterpreter() - .toHttp(ep) @@ HttpAppMiddleware.allowZIO((_: Request) => ref.update(_ + 1).as(true)) + ref <- Ref.make(0) + ep = endpoint.get + .in("p1") + .out(stringBody) + .zServerLogic[Any](_ => ref.updateAndGet(_ + 1).map(_.toString)) + route = ZioHttpInterpreter() + .toHttp(ep) @@ HttpAppMiddleware.allowZIO((_: Request) => ref.update(_ + 1).as(true)) result <- route - .runZIO(Request.get(url = URL(Path.empty / "p1"))) - .flatMap(response => response.body.asString) - .map(_ shouldBe "2") - .catchAll(_ => ZIO.succeed(fail("Unable to extract body from Http response"))) + .runZIO(Request.get(url = URL(Path.empty / "p1"))) + .flatMap(response => response.body.asString) + .map(_ shouldBe "2") + .catchAll(_ => ZIO.succeed(fail("Unable to extract body from Http response"))) } yield result Unsafe.unsafe(implicit u => r.unsafe.runToFuture(test)) @@ -142,7 +120,8 @@ class ZioHttpServerTest extends TestSuite { val backendStub: TapirStubInterpreter[Task, ZioStreams, Unit] = TapirStubInterpreter[Task, ZioStreams](SttpBackendStub[Task, ZioStreams](new RIOMonadError[Any])) - val endpointModel: PublicEndpoint[ZStream[Any, Throwable, Byte], Unit, ZStream[Any, Throwable, Byte], ZioStreams] = + val endpointModel + : PublicEndpoint[ZStream[Any, Throwable, Byte], Unit, ZStream[Any, Throwable, Byte], ZioStreams] = endpoint.post .in("hello") .in(streamBinaryBody(ZioStreams)(CsvCodecFormat)) @@ -151,16 +130,16 @@ class ZioHttpServerTest extends TestSuite { val streamingEndpoint: sttp.tapir.ztapir.ZServerEndpoint[Any, ZioStreams] = endpointModel .zServerLogic(stream => - ZIO.succeed({ + ZIO.succeed { stream .via(ZPipeline.utf8Decode) .via(ZPipeline.splitLines) .via(ZPipeline.intersperse(java.lang.System.lineSeparator())) .via(ZPipeline.utf8Encode) - }) + } ) - val inputStrings = List("Hello,how,are,you", "I,am,good,thanks") - val input: ZStream[Any, Nothing, Byte] = + val inputStrings = List("Hello,how,are,you", "I,am,good,thanks") + val input: ZStream[Any, Nothing, Byte] = ZStream(inputStrings: _*) .via(ZPipeline.intersperse(java.lang.System.lineSeparator())) .mapConcat(_.getBytes(Charset.forName("UTF-8"))) @@ -217,6 +196,10 @@ class ZioHttpServerTest extends TestSuite { ).tests() ++ new ServerStreamingTests(createServerTest, ZioStreams).tests() ++ new ZioHttpCompositionTest(createServerTest).tests() ++ + new ServerWebSocketTests(createServerTest, ZioStreams) { + override def functionToPipe[A, B](f: A => B): ZioStreams.Pipe[A, B] = in => in.map(f) + override def emptyPipe[A, B]: ZioStreams.Pipe[A, B] = _ => ZStream.empty + }.tests() ++ additionalTests() } } diff --git a/server/zio-http-server/src/test/scala/sttp/tapir/server/ziohttp/ZioHttpTestServerInterpreter.scala b/server/zio-http-server/src/test/scala/sttp/tapir/server/ziohttp/ZioHttpTestServerInterpreter.scala index bb2e5a2bbd..e8b2c8aa01 100644 --- a/server/zio-http-server/src/test/scala/sttp/tapir/server/ziohttp/ZioHttpTestServerInterpreter.scala +++ b/server/zio-http-server/src/test/scala/sttp/tapir/server/ziohttp/ZioHttpTestServerInterpreter.scala @@ -3,6 +3,7 @@ package sttp.tapir.server.ziohttp import cats.data.NonEmptyList import cats.effect.{IO, Resource} import io.netty.channel.{ChannelFactory, EventLoopGroup, ServerChannel} +import sttp.capabilities.WebSockets import sttp.capabilities.zio.ZioStreams import sttp.tapir.server.ServerEndpoint import sttp.tapir.server.tests.TestServerInterpreter @@ -16,9 +17,9 @@ class ZioHttpTestServerInterpreter( channelFactory: ZLayer[Any, Nothing, ChannelFactory[ServerChannel]] )(implicit trace: Trace -) extends TestServerInterpreter[Task, ZioStreams, ZioHttpServerOptions[Any], Http[Any, Throwable, Request, Response]] { +) extends TestServerInterpreter[Task, ZioStreams with WebSockets, ZioHttpServerOptions[Any], Http[Any, Throwable, Request, Response]] { - override def route(es: List[ServerEndpoint[ZioStreams, Task]], interceptors: Interceptors): Http[Any, Throwable, Request, Response] = { + override def route(es: List[ServerEndpoint[ZioStreams with WebSockets, Task]], interceptors: Interceptors): Http[Any, Throwable, Request, Response] = { val serverOptions: ZioHttpServerOptions[Any] = interceptors(ZioHttpServerOptions.customiseInterceptors).options ZioHttpInterpreter(serverOptions).toHttp(es) } From 10c02966e40ea72e703125786b78d9d6fada2edc Mon Sep 17 00:00:00 2001 From: Yehia AboSedira Date: Fri, 1 Sep 2023 01:51:47 +0200 Subject: [PATCH 2/4] Initial implementation of zio-http web-sockets support --- .../server/ziohttp/ZioHttpBodyListener.scala | 11 +-- .../server/ziohttp/ZioHttpInterpreter.scala | 85 ++++++++++++------- .../ziohttp/ZioHttpToResponseBody.scala | 20 +++-- .../tapir/server/ziohttp/ZioWebSockets.scala | 54 ++++++++++++ .../sttp/tapir/server/ziohttp/package.scala | 11 +++ .../server/ziohttp/ZioHttpServerTest.scala | 67 +++++++-------- .../ZioHttpTestServerInterpreter.scala | 5 +- 7 files changed, 176 insertions(+), 77 deletions(-) create mode 100644 server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioWebSockets.scala create mode 100644 server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/package.scala diff --git a/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioHttpBodyListener.scala b/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioHttpBodyListener.scala index 1c690a01a2..3289e5cc23 100644 --- a/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioHttpBodyListener.scala +++ b/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioHttpBodyListener.scala @@ -6,14 +6,14 @@ import zio.stream.ZStream import scala.util.{Failure, Success, Try} -private[ziohttp] class ZioHttpBodyListener[R] extends BodyListener[RIO[R, *], ZioHttpResponseBody] { - override def onComplete(body: ZioHttpResponseBody)(cb: Try[Unit] => RIO[R, Unit]): RIO[R, ZioHttpResponseBody] = +private[ziohttp] class ZioHttpBodyListener[R] extends BodyListener[RIO[R, *], ZioResponseBody] { + override def onComplete(body: ZioResponseBody)(cb: Try[Unit] => RIO[R, Unit]): RIO[R, ZioResponseBody] = ZIO .environmentWithZIO[R] .apply { r => body match { - case ZioStreamHttpResponseBody(stream, contentLength) => - ZIO.succeed( + case Right(ZioStreamHttpResponseBody(stream, contentLength)) => + ZIO.right( ZioStreamHttpResponseBody( stream.onError(cause => cb(Failure(cause.squash)).orDie.provideEnvironment(r)) ++ ZStream .fromZIO(cb(Success(()))) @@ -22,7 +22,8 @@ private[ziohttp] class ZioHttpBodyListener[R] extends BodyListener[RIO[R, *], Zi contentLength ) ) - case raw: ZioRawHttpResponseBody => cb(Success(())).provideEnvironment(r).map(_ => raw) + case raw @ Right(_: ZioRawHttpResponseBody) => cb(Success(())).provideEnvironment(r).map(_ => raw) + case ws @ Left(_) => cb(Success(())).provideEnvironment(r).map(_ => ws) } } } diff --git a/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioHttpInterpreter.scala b/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioHttpInterpreter.scala index 6137cf100a..6def95d57e 100644 --- a/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioHttpInterpreter.scala +++ b/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioHttpInterpreter.scala @@ -7,8 +7,10 @@ import sttp.monad.MonadError import sttp.tapir.server.interceptor.RequestResult import sttp.tapir.server.interceptor.reject.RejectInterceptor import sttp.tapir.server.interpreter.{FilterServerEndpoints, ServerInterpreter} +import sttp.tapir.server.model.ServerResponse import sttp.tapir.ztapir._ import zio._ +import zio.http.ChannelEvent.Read import zio.http.{Header => ZioHttpHeader, Headers => ZioHttpHeaders, _} trait ZioHttpInterpreter[R] { @@ -19,16 +21,16 @@ trait ZioHttpInterpreter[R] { def toHttp[R2](ses: List[ZServerEndpoint[R2, ZioStreams with WebSockets]]): HttpApp[R & R2, Throwable] = { implicit val bodyListener: ZioHttpBodyListener[R & R2] = new ZioHttpBodyListener[R & R2] - implicit val monadError: MonadError[RIO[R & R2, *]] = new RIOMonadError[R & R2] - val widenedSes = ses.map(_.widen[R & R2]) - val widenedServerOptions = zioHttpServerOptions.widen[R & R2] - val zioHttpRequestBody = new ZioHttpRequestBody(widenedServerOptions) - val zioHttpResponseBody = new ZioHttpToResponseBody - val interceptors = RejectInterceptor.disableWhenSingleEndpoint(widenedServerOptions.interceptors, widenedSes) + implicit val monadError: MonadError[RIO[R & R2, *]] = new RIOMonadError[R & R2] + val widenedSes = ses.map(_.widen[R & R2]) + val widenedServerOptions = zioHttpServerOptions.widen[R & R2] + val zioHttpRequestBody = new ZioHttpRequestBody(widenedServerOptions) + val zioHttpResponseBody = new ZioHttpToResponseBody + val interceptors = RejectInterceptor.disableWhenSingleEndpoint(widenedServerOptions.interceptors, widenedSes) def handleRequest(req: Request, filteredEndpoints: List[ZServerEndpoint[R & R2, ZioStreams with WebSockets]]) = Handler.fromZIO { - val interpreter = new ServerInterpreter[ZioStreams with WebSockets, RIO[R & R2, *], ZioHttpResponseBody, ZioStreams]( + val interpreter = new ServerInterpreter[ZioStreams with WebSockets, RIO[R & R2, *], ZioResponseBody, ZioStreams]( _ => filteredEndpoints, zioHttpRequestBody, zioHttpResponseBody, @@ -42,27 +44,27 @@ trait ZioHttpInterpreter[R] { error => ZIO.fail(error), { case RequestResult.Response(resp) => - val baseHeaders = resp.headers.groupBy(_.name).flatMap(sttpToZioHttpHeader).toList - val allHeaders = resp.body.flatMap(_.contentLength) match { - case Some(contentLength) if resp.contentLength.isEmpty => - ZioHttpHeader.ContentLength(contentLength) :: baseHeaders - case _ => baseHeaders - } - val statusCode = resp.code.code + resp.body match { + case None => handleHttpResponse(resp, None) + case Some(Right(body)) => handleHttpResponse(resp, Some(body)) + case Some(Left(body)) => + Handler.webSocket { channel => + { + channel.receiveAll { + case ChannelEvent.Read(message) => + for { + m <- body(message) + _ <- ZIO.foldLeft(m)(())((_, z) => channel.send(Read(z))) + } yield () - ZIO.succeed( - Response( - status = Status.fromInt(statusCode).getOrElse(Status.Custom(statusCode)), - headers = ZioHttpHeaders(allHeaders), - body = resp.body - .map { - case ZioStreamHttpResponseBody(stream, _) => Body.fromStream(stream) - case ZioRawHttpResponseBody(chunk, _) => Body.fromChunk(chunk) + case v => + channel.send(v) + } } - .getOrElse(Body.empty) - ) - ) - case RequestResult.Failure(_) => + }.toResponse + } + + case RequestResult.Failure(_) => ZIO.fail( new RuntimeException( s"The path: ${req.path} matches the shape of some endpoint, but none of the " + @@ -76,11 +78,11 @@ trait ZioHttpInterpreter[R] { } val serverEndpointsFilter = FilterServerEndpoints[ZioStreams with WebSockets, RIO[R & R2, *]](widenedSes) - val singleEndpoint = widenedSes.size == 1 + val singleEndpoint = widenedSes.size == 1 Http.fromOptionalHandlerZIO { request => // pre-filtering the endpoints by shape to determine, if this request should be handled by tapir - val filteredEndpoints = serverEndpointsFilter.apply(ZioHttpServerRequest(request)) + val filteredEndpoints = serverEndpointsFilter.apply(ZioHttpServerRequest(request)) val filteredEndpoints2 = if (singleEndpoint) { // If we are interpreting a single endpoint, we verify that the method matches as well; in case it doesn't, // we refuse to handle the request, allowing other ZIO Http routes to handle it. Otherwise even if the method @@ -98,6 +100,31 @@ trait ZioHttpInterpreter[R] { } } + private def handleHttpResponse( + resp: ServerResponse[ZioResponseBody], + body: Option[ZioHttpResponseBody] + ) = { + val baseHeaders = resp.headers.groupBy(_.name).flatMap(sttpToZioHttpHeader).toList + val allHeaders = body.flatMap(_.contentLength) match { + case Some(contentLength) if resp.contentLength.isEmpty => ZioHttpHeader.ContentLength(contentLength) :: baseHeaders + case _ => baseHeaders + } + val statusCode = resp.code.code + + ZIO.succeed( + Response( + status = Status.fromInt(statusCode).getOrElse(Status.Custom(statusCode)), + headers = ZioHttpHeaders(allHeaders), + body = body + .map { + case ZioStreamHttpResponseBody(stream, _) => Body.fromStream(stream) + case ZioRawHttpResponseBody(chunk, _) => Body.fromChunk(chunk) + } + .getOrElse(Body.empty) + ) + ) + } + private def sttpToZioHttpHeader(hl: (String, Seq[SttpHeader])): List[ZioHttpHeader] = List(ZioHttpHeader.Custom(hl._1, hl._2.map(_.value).mkString(", "))) } @@ -107,7 +134,7 @@ object ZioHttpInterpreter { new ZioHttpInterpreter[R] { override def zioHttpServerOptions: ZioHttpServerOptions[R] = serverOptions } - def apply(): ZioHttpInterpreter[Any] = + def apply(): ZioHttpInterpreter[Any] = new ZioHttpInterpreter[Any] { override def zioHttpServerOptions: ZioHttpServerOptions[Any] = ZioHttpServerOptions.default[Any] } diff --git a/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioHttpToResponseBody.scala b/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioHttpToResponseBody.scala index b93abe1fbd..af412f40fb 100644 --- a/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioHttpToResponseBody.scala +++ b/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioHttpToResponseBody.scala @@ -10,26 +10,31 @@ import zio.stream.ZStream import java.nio.ByteBuffer import java.nio.charset.Charset -class ZioHttpToResponseBody extends ToResponseBody[ZioHttpResponseBody, ZioStreams] { +class ZioHttpToResponseBody extends ToResponseBody[ZioResponseBody, ZioStreams] { override val streams: ZioStreams = ZioStreams - override def fromRawValue[R](v: R, headers: HasHeaders, format: CodecFormat, bodyType: RawBodyType[R]): ZioHttpResponseBody = - rawValueToEntity(bodyType, v) + override def fromRawValue[R]( + v: R, + headers: HasHeaders, + format: CodecFormat, + bodyType: RawBodyType[R] + ): ZioResponseBody = + Right(rawValueToEntity(bodyType, v)) override def fromStreamValue( v: streams.BinaryStream, headers: HasHeaders, format: CodecFormat, charset: Option[Charset] - ): ZioHttpResponseBody = ZioStreamHttpResponseBody(v, None) + ): ZioResponseBody = Right(ZioStreamHttpResponseBody(v, None)) override def fromWebSocketPipe[REQ, RESP]( pipe: streams.Pipe[REQ, RESP], o: WebSocketBodyOutput[streams.Pipe[REQ, RESP], REQ, RESP, _, ZioStreams] - ): ZioHttpResponseBody = - ZioStreamHttpResponseBody(ZStream.empty, None) // TODO + ): ZioResponseBody = + Left(ZioWebSockets.pipeToBody(pipe, o)) - private def rawValueToEntity[R](bodyType: RawBodyType[R], r: R): ZioHttpResponseBody = { + private def rawValueToEntity[R](bodyType: RawBodyType[R], r: R): ZioHttpResponseBody = bodyType match { case RawBodyType.StringBody(charset) => val bytes = r.toString.getBytes(charset) @@ -71,5 +76,4 @@ class ZioHttpToResponseBody extends ToResponseBody[ZioHttpResponseBody, ZioStrea .getOrElse(ZioStreamHttpResponseBody(ZStream.fromPath(tapirFile.file.toPath), Some(tapirFile.file.length))) case RawBodyType.MultipartBody(_, _) => throw new UnsupportedOperationException("Multipart is not supported") } - } } diff --git a/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioWebSockets.scala b/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioWebSockets.scala new file mode 100644 index 0000000000..32ef887bbf --- /dev/null +++ b/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioWebSockets.scala @@ -0,0 +1,54 @@ +package sttp.tapir.server.ziohttp +import sttp.capabilities.zio.ZioStreams +import sttp.capabilities.zio.ZioStreams.Pipe +import sttp.tapir.model.WebSocketFrameDecodeFailure +import sttp.tapir.{DecodeResult, WebSocketBodyOutput} +import sttp.ws.WebSocketFrame +import zio.http.{WebSocketFrame => ZWebSocketFrame} +import zio.stream.ZStream +import zio.{Chunk, ZIO} + +object ZioWebSockets { + def pipeToBody[REQ, RESP]( + pipe: Pipe[REQ, RESP], + o: WebSocketBodyOutput[Pipe[REQ, RESP], REQ, RESP, _, ZioStreams] + ): F2F = { in => + ZStream + .from(in) + .map(zFrameToFrame) + .map { + case WebSocketFrame.Close(_, _) if !o.decodeCloseRequests => None + case f => + o.requests.decode(f) match { + case failure: DecodeResult.Failure => throw new WebSocketFrameDecodeFailure(f, failure) + case DecodeResult.Value(v) => Some(v) + } + } + .collectWhileSome + .viaFunction(pipe) + .map(o.responses.encode) + .map(frameToZFrame) + .tap(v => zio.ZIO.succeed(println(v))) + .runFoldZIO(List.empty[ZWebSocketFrame])((s, ws) => ZIO.succeed(ws +: s)) + } + + private def zFrameToFrame(f: ZWebSocketFrame): WebSocketFrame = + f match { + case ZWebSocketFrame.Text(text) => WebSocketFrame.Text(text, f.isFinal, rsv = None) + case ZWebSocketFrame.Binary(buffer) => WebSocketFrame.Binary(buffer.toArray, f.isFinal, rsv = None) + case ZWebSocketFrame.Continuation(buffer) => WebSocketFrame.Binary(buffer.toArray, f.isFinal, rsv = None) + case ZWebSocketFrame.Ping => WebSocketFrame.ping + case ZWebSocketFrame.Pong => WebSocketFrame.pong + case ZWebSocketFrame.Close(status, reason) => WebSocketFrame.Close(status, reason.getOrElse("")) + case _ => WebSocketFrame.Binary(Array.empty[Byte], f.isFinal, rsv = None) + } + + private def frameToZFrame(f: WebSocketFrame): ZWebSocketFrame = + f match { + case WebSocketFrame.Text(p, finalFragment, _) => ZWebSocketFrame.Text(p, finalFragment) + case WebSocketFrame.Binary(p, finalFragment, _) => ZWebSocketFrame.Binary(Chunk.fromArray(p), finalFragment) + case WebSocketFrame.Ping(_) => ZWebSocketFrame.Ping + case WebSocketFrame.Pong(_) => ZWebSocketFrame.Pong + case WebSocketFrame.Close(code, reason) => ZWebSocketFrame.Close(code, Some(reason)) + } +} diff --git a/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/package.scala b/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/package.scala new file mode 100644 index 0000000000..28697c399a --- /dev/null +++ b/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/package.scala @@ -0,0 +1,11 @@ +package sttp.tapir.server +import zio.Task +import zio.http.{WebSocketFrame => ZWebSocketFrame} + +package object ziohttp { + type F2F = ZWebSocketFrame => Task[List[ZWebSocketFrame]] + + type ZioResponseBody = + Either[F2F, ZioHttpResponseBody] + +} diff --git a/server/zio-http-server/src/test/scala/sttp/tapir/server/ziohttp/ZioHttpServerTest.scala b/server/zio-http-server/src/test/scala/sttp/tapir/server/ziohttp/ZioHttpServerTest.scala index 9327b7d86d..07b4502598 100644 --- a/server/zio-http-server/src/test/scala/sttp/tapir/server/ziohttp/ZioHttpServerTest.scala +++ b/server/zio-http-server/src/test/scala/sttp/tapir/server/ziohttp/ZioHttpServerTest.scala @@ -39,7 +39,7 @@ class ZioHttpServerTest extends TestSuite { println(s"Test ${test.name} failed, retrying.") e.printStackTrace() (if (count == 1) super.withFixture(test) else withFixture(test, count - 1)).toFuture - case other => Future.successful(other) + case other => Future.successful(other) }) } @@ -50,23 +50,23 @@ class ZioHttpServerTest extends TestSuite { .scoped[IO, Any, ZEnvironment[EventLoopGroup with ChannelFactory[ServerChannel]]]({ val eventConfig = ZLayer.succeed(new EventLoopGroups.Config { def channelType = ChannelType.AUTO - val nThreads = 0 + val nThreads = 0 }) val channelConfig: ZLayer[Any, Nothing, ChannelType.Config] = eventConfig (channelConfig >>> ChannelFactories.Server.fromConfig) ++ (eventConfig >>> EventLoopGroups.live) }.build) .map { nettyDeps => - val eventLoopGroup = ZLayer.succeed(nettyDeps.get[EventLoopGroup]) - val channelFactory = ZLayer.succeed(nettyDeps.get[ChannelFactory[ServerChannel]]) - val interpreter = new ZioHttpTestServerInterpreter(eventLoopGroup, channelFactory) + val eventLoopGroup = ZLayer.succeed(nettyDeps.get[EventLoopGroup]) + val channelFactory = ZLayer.succeed(nettyDeps.get[ChannelFactory[ServerChannel]]) + val interpreter = new ZioHttpTestServerInterpreter(eventLoopGroup, channelFactory) val createServerTest = new DefaultCreateServerTest(backend, interpreter) def additionalTests(): List[Test] = List( // https://github.com/softwaremill/tapir/issues/1914 Test("zio http route can be called with runZIO") { - val ep = endpoint.get.in("p1").out(stringBody).zServerLogic[Any](_ => ZIO.succeed("response")) - val route = ZioHttpInterpreter().toHttp(ep) + val ep = endpoint.get.in("p1").out(stringBody).zServerLogic[Any](_ => ZIO.succeed("response")) + val route = ZioHttpInterpreter().toHttp(ep) val test: UIO[Assertion] = route .runZIO(Request.get(url = URL.apply(Path.empty / "p1"))) .flatMap(response => response.body.asString) @@ -76,36 +76,36 @@ class ZioHttpServerTest extends TestSuite { }, Test("zio http middlewares run before the handler") { val test: UIO[Assertion] = for { - p <- Promise.make[Nothing, Unit] - ep = endpoint.get - .in("p1") - .out(stringBody) - .zServerLogic[Any](_ => p.await.timeout(time.Duration.ofSeconds(1)) *> ZIO.succeed("Ok")) - int = ZioHttpInterpreter().toHttp(ep) - route = int @@ HttpAppMiddleware.allowZIO((_: Request) => p.succeed(()).as(true)) + p <- Promise.make[Nothing, Unit] + ep = endpoint.get + .in("p1") + .out(stringBody) + .zServerLogic[Any](_ => p.await.timeout(time.Duration.ofSeconds(1)) *> ZIO.succeed("Ok")) + int = ZioHttpInterpreter().toHttp(ep) + route = int @@ HttpAppMiddleware.allowZIO((_: Request) => p.succeed(()).as(true)) result <- route - .runZIO(Request.get(url = URL(Path.empty / "p1"))) - .flatMap(response => response.body.asString) - .map(_ shouldBe "Ok") - .catchAll(_ => ZIO.succeed(fail("Unable to extract body from Http response"))) + .runZIO(Request.get(url = URL(Path.empty / "p1"))) + .flatMap(response => response.body.asString) + .map(_ shouldBe "Ok") + .catchAll(_ => ZIO.succeed(fail("Unable to extract body from Http response"))) } yield result Unsafe.unsafe(implicit u => r.unsafe.runToFuture(test)) }, Test("zio http middlewares only run once") { val test: UIO[Assertion] = for { - ref <- Ref.make(0) - ep = endpoint.get - .in("p1") - .out(stringBody) - .zServerLogic[Any](_ => ref.updateAndGet(_ + 1).map(_.toString)) - route = ZioHttpInterpreter() - .toHttp(ep) @@ HttpAppMiddleware.allowZIO((_: Request) => ref.update(_ + 1).as(true)) + ref <- Ref.make(0) + ep = endpoint.get + .in("p1") + .out(stringBody) + .zServerLogic[Any](_ => ref.updateAndGet(_ + 1).map(_.toString)) + route = ZioHttpInterpreter() + .toHttp(ep) @@ HttpAppMiddleware.allowZIO((_: Request) => ref.update(_ + 1).as(true)) result <- route - .runZIO(Request.get(url = URL(Path.empty / "p1"))) - .flatMap(response => response.body.asString) - .map(_ shouldBe "2") - .catchAll(_ => ZIO.succeed(fail("Unable to extract body from Http response"))) + .runZIO(Request.get(url = URL(Path.empty / "p1"))) + .flatMap(response => response.body.asString) + .map(_ shouldBe "2") + .catchAll(_ => ZIO.succeed(fail("Unable to extract body from Http response"))) } yield result Unsafe.unsafe(implicit u => r.unsafe.runToFuture(test)) @@ -120,8 +120,7 @@ class ZioHttpServerTest extends TestSuite { val backendStub: TapirStubInterpreter[Task, ZioStreams, Unit] = TapirStubInterpreter[Task, ZioStreams](SttpBackendStub[Task, ZioStreams](new RIOMonadError[Any])) - val endpointModel - : PublicEndpoint[ZStream[Any, Throwable, Byte], Unit, ZStream[Any, Throwable, Byte], ZioStreams] = + val endpointModel: PublicEndpoint[ZStream[Any, Throwable, Byte], Unit, ZStream[Any, Throwable, Byte], ZioStreams] = endpoint.post .in("hello") .in(streamBinaryBody(ZioStreams)(CsvCodecFormat)) @@ -138,8 +137,8 @@ class ZioHttpServerTest extends TestSuite { .via(ZPipeline.utf8Encode) } ) - val inputStrings = List("Hello,how,are,you", "I,am,good,thanks") - val input: ZStream[Any, Nothing, Byte] = + val inputStrings = List("Hello,how,are,you", "I,am,good,thanks") + val input: ZStream[Any, Nothing, Byte] = ZStream(inputStrings: _*) .via(ZPipeline.intersperse(java.lang.System.lineSeparator())) .mapConcat(_.getBytes(Charset.forName("UTF-8"))) @@ -198,7 +197,7 @@ class ZioHttpServerTest extends TestSuite { new ZioHttpCompositionTest(createServerTest).tests() ++ new ServerWebSocketTests(createServerTest, ZioStreams) { override def functionToPipe[A, B](f: A => B): ZioStreams.Pipe[A, B] = in => in.map(f) - override def emptyPipe[A, B]: ZioStreams.Pipe[A, B] = _ => ZStream.empty + override def emptyPipe[A, B]: ZioStreams.Pipe[A, B] = _ => ZStream.empty }.tests() ++ additionalTests() } diff --git a/server/zio-http-server/src/test/scala/sttp/tapir/server/ziohttp/ZioHttpTestServerInterpreter.scala b/server/zio-http-server/src/test/scala/sttp/tapir/server/ziohttp/ZioHttpTestServerInterpreter.scala index e8b2c8aa01..3a53e49ae0 100644 --- a/server/zio-http-server/src/test/scala/sttp/tapir/server/ziohttp/ZioHttpTestServerInterpreter.scala +++ b/server/zio-http-server/src/test/scala/sttp/tapir/server/ziohttp/ZioHttpTestServerInterpreter.scala @@ -19,7 +19,10 @@ class ZioHttpTestServerInterpreter( trace: Trace ) extends TestServerInterpreter[Task, ZioStreams with WebSockets, ZioHttpServerOptions[Any], Http[Any, Throwable, Request, Response]] { - override def route(es: List[ServerEndpoint[ZioStreams with WebSockets, Task]], interceptors: Interceptors): Http[Any, Throwable, Request, Response] = { + override def route( + es: List[ServerEndpoint[ZioStreams with WebSockets, Task]], + interceptors: Interceptors + ): Http[Any, Throwable, Request, Response] = { val serverOptions: ZioHttpServerOptions[Any] = interceptors(ZioHttpServerOptions.customiseInterceptors).options ZioHttpInterpreter(serverOptions).toHttp(es) } From 806e208f3d0d91be172532fdb78b961e7556cac7 Mon Sep 17 00:00:00 2001 From: Yehia AboSedira Date: Sat, 2 Sep 2023 03:03:05 +0200 Subject: [PATCH 3/4] Handle fragmented frames --- .../server/ziohttp/ZioHttpInterpreter.scala | 22 +-- .../tapir/server/ziohttp/ZioWebSockets.scala | 139 +++++++++++++----- .../sttp/tapir/server/ziohttp/package.scala | 6 +- 3 files changed, 113 insertions(+), 54 deletions(-) diff --git a/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioHttpInterpreter.scala b/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioHttpInterpreter.scala index 6def95d57e..1f0e884146 100644 --- a/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioHttpInterpreter.scala +++ b/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioHttpInterpreter.scala @@ -10,7 +10,6 @@ import sttp.tapir.server.interpreter.{FilterServerEndpoints, ServerInterpreter} import sttp.tapir.server.model.ServerResponse import sttp.tapir.ztapir._ import zio._ -import zio.http.ChannelEvent.Read import zio.http.{Header => ZioHttpHeader, Headers => ZioHttpHeaders, _} trait ZioHttpInterpreter[R] { @@ -47,21 +46,7 @@ trait ZioHttpInterpreter[R] { resp.body match { case None => handleHttpResponse(resp, None) case Some(Right(body)) => handleHttpResponse(resp, Some(body)) - case Some(Left(body)) => - Handler.webSocket { channel => - { - channel.receiveAll { - case ChannelEvent.Read(message) => - for { - m <- body(message) - _ <- ZIO.foldLeft(m)(())((_, z) => channel.send(Read(z))) - } yield () - - case v => - channel.send(v) - } - } - }.toResponse + case Some(Left(body)) => handleWebSocketResponse(body) } case RequestResult.Failure(_) => @@ -100,6 +85,10 @@ trait ZioHttpInterpreter[R] { } } + private def handleWebSocketResponse(webSocketHandler: WebSocketHandler) = { + Handler.webSocket(webSocketHandler).toResponse + } + private def handleHttpResponse( resp: ServerResponse[ZioResponseBody], body: Option[ZioHttpResponseBody] @@ -130,6 +119,7 @@ trait ZioHttpInterpreter[R] { } object ZioHttpInterpreter { + def apply[R](serverOptions: ZioHttpServerOptions[R]): ZioHttpInterpreter[R] = new ZioHttpInterpreter[R] { override def zioHttpServerOptions: ZioHttpServerOptions[R] = serverOptions diff --git a/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioWebSockets.scala b/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioWebSockets.scala index 32ef887bbf..1b9f84a21a 100644 --- a/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioWebSockets.scala +++ b/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioWebSockets.scala @@ -3,52 +3,121 @@ import sttp.capabilities.zio.ZioStreams import sttp.capabilities.zio.ZioStreams.Pipe import sttp.tapir.model.WebSocketFrameDecodeFailure import sttp.tapir.{DecodeResult, WebSocketBodyOutput} -import sttp.ws.WebSocketFrame -import zio.http.{WebSocketFrame => ZWebSocketFrame} +import sttp.ws.{WebSocketFrame => SttpWebSocketFrame} +import zio.http.ChannelEvent.Read +import zio.http.{ChannelEvent, WebSocketChannel, WebSocketChannelEvent, WebSocketFrame => ZioWebSocketFrame} import zio.stream.ZStream -import zio.{Chunk, ZIO} +import zio.{Chunk, Task, ZIO, stream} object ZioWebSockets { - def pipeToBody[REQ, RESP]( + private val NormalClosureStatusCode = 1000 + private val AbnormalClosureStatusCode = 1006 + + def pipeToBody[REQ, RESP](pipe: Pipe[REQ, RESP], o: WebSocketBodyOutput[Pipe[REQ, RESP], REQ, RESP, _, ZioStreams]): WebSocketHandler = + channel => { + val reqToRespPipeline = reqToResp(pipe, o) + + channel.receiveAll { + case ChannelEvent.Read(ZioWebSocketFrame.Ping) if o.autoPongOnPing => + channel.send(Read(ZioWebSocketFrame.Pong)) + case ChannelEvent.Read(message) if message.isFinal => + processWebSocketFrame(channel, reqToRespPipeline, message) + case ChannelEvent.Read(message) => // Fragmented message + for { + message <- accumulateFrames(channel, message) + response <- processWebSocketFrame(channel, reqToRespPipeline, message) + } yield response + case ChannelEvent.Unregistered => + channel.send(Read(ZioWebSocketFrame.close(NormalClosureStatusCode))) + case _ => + ZIO.unit + } + } + + private def reqToResp[REQ, RESP]( pipe: Pipe[REQ, RESP], o: WebSocketBodyOutput[Pipe[REQ, RESP], REQ, RESP, _, ZioStreams] - ): F2F = { in => + ): stream.Stream[Throwable, ZioWebSocketFrame] => stream.Stream[Throwable, ZioWebSocketFrame] = { + in: stream.Stream[Throwable, ZioWebSocketFrame] => + in + .map(zFrameToFrame) + .map { + case SttpWebSocketFrame.Close(_, _) if !o.decodeCloseRequests => None + case SttpWebSocketFrame.Pong(_) if o.ignorePong => None + case f: SttpWebSocketFrame => + o.requests.decode(f) match { + case failure: DecodeResult.Failure => throw new WebSocketFrameDecodeFailure(f, failure) + case DecodeResult.Value(v) => Some(v) + } + } + .collectWhileSome + .viaFunction(pipe) + .map(o.responses.encode) + .map(frameToZFrame) + } + + private def processWebSocketFrame( + channel: WebSocketChannel, + body: stream.Stream[Throwable, ZioWebSocketFrame] => stream.Stream[Throwable, ZioWebSocketFrame], + message: ZioWebSocketFrame + ) = { ZStream - .from(in) - .map(zFrameToFrame) - .map { - case WebSocketFrame.Close(_, _) if !o.decodeCloseRequests => None - case f => - o.requests.decode(f) match { - case failure: DecodeResult.Failure => throw new WebSocketFrameDecodeFailure(f, failure) - case DecodeResult.Value(v) => Some(v) - } - } - .collectWhileSome - .viaFunction(pipe) - .map(o.responses.encode) - .map(frameToZFrame) - .tap(v => zio.ZIO.succeed(println(v))) - .runFoldZIO(List.empty[ZWebSocketFrame])((s, ws) => ZIO.succeed(ws +: s)) + .from(message) + .viaFunction(body) + .mapZIO(wsf => channel.send(Read(wsf))) + .runDrain + } + + private def accumulateFrames(channel: WebSocketChannel, webSocketFrame: ZioWebSocketFrame): Task[ZioWebSocketFrame] = { + ZIO.iterate(webSocketFrame)(!_.isFinal) { wsf => + for { + channelEvent <- channel.receive + accumulatedWebSocketFrame <- handleChannelEvent(channel, channelEvent, wsf) + } yield accumulatedWebSocketFrame + } } - private def zFrameToFrame(f: ZWebSocketFrame): WebSocketFrame = + private def handleChannelEvent( + channel: WebSocketChannel, + channelEvent: WebSocketChannelEvent, + acc: ZioWebSocketFrame + ): Task[ZioWebSocketFrame] = { + channelEvent match { + case ChannelEvent.ExceptionCaught(cause) => + channel.send(Read(ZioWebSocketFrame.close(AbnormalClosureStatusCode, Some(cause.getMessage)))) *> + channel.shutdown.map(_ => acc) + case Read(ZioWebSocketFrame.Continuation(newBuffer)) => + acc match { + case b @ ZioWebSocketFrame.Binary(bytes) => ZIO.succeed(b.copy(bytes ++ newBuffer)) + case t @ ZioWebSocketFrame.Text(text) => ZIO.succeed(t.copy(text + new String(newBuffer.toArray))) + case ZioWebSocketFrame.Close(status, reason) => + ZIO.fail(new RuntimeException(s"Received unexpected close frame: $status, $reason")) + case ZioWebSocketFrame.Continuation(buffer) => + channel.send(Read(ZioWebSocketFrame.Continuation(buffer))).map(_ => acc) + case ZioWebSocketFrame.Ping => channel.send(Read(ZioWebSocketFrame.Pong)).map(_ => acc) + case ZioWebSocketFrame.Pong => channel.send(Read(ZioWebSocketFrame.Ping)).map(_ => acc) + case _ => ZIO.succeed(acc) + } + case _ => ZIO.succeed(acc) + } + } + private def zFrameToFrame(f: ZioWebSocketFrame): SttpWebSocketFrame = f match { - case ZWebSocketFrame.Text(text) => WebSocketFrame.Text(text, f.isFinal, rsv = None) - case ZWebSocketFrame.Binary(buffer) => WebSocketFrame.Binary(buffer.toArray, f.isFinal, rsv = None) - case ZWebSocketFrame.Continuation(buffer) => WebSocketFrame.Binary(buffer.toArray, f.isFinal, rsv = None) - case ZWebSocketFrame.Ping => WebSocketFrame.ping - case ZWebSocketFrame.Pong => WebSocketFrame.pong - case ZWebSocketFrame.Close(status, reason) => WebSocketFrame.Close(status, reason.getOrElse("")) - case _ => WebSocketFrame.Binary(Array.empty[Byte], f.isFinal, rsv = None) + case ZioWebSocketFrame.Text(text) => SttpWebSocketFrame.Text(text, f.isFinal, rsv = None) + case ZioWebSocketFrame.Binary(buffer) => SttpWebSocketFrame.Binary(buffer.toArray, f.isFinal, rsv = None) + case ZioWebSocketFrame.Continuation(buffer) => SttpWebSocketFrame.Binary(buffer.toArray, f.isFinal, rsv = None) + case ZioWebSocketFrame.Ping => SttpWebSocketFrame.ping + case ZioWebSocketFrame.Pong => SttpWebSocketFrame.pong + case ZioWebSocketFrame.Close(status, reason) => SttpWebSocketFrame.Close(status, reason.getOrElse("")) + case _ => SttpWebSocketFrame.Binary(Array.empty[Byte], f.isFinal, rsv = None) } - private def frameToZFrame(f: WebSocketFrame): ZWebSocketFrame = + private def frameToZFrame(f: SttpWebSocketFrame): ZioWebSocketFrame = f match { - case WebSocketFrame.Text(p, finalFragment, _) => ZWebSocketFrame.Text(p, finalFragment) - case WebSocketFrame.Binary(p, finalFragment, _) => ZWebSocketFrame.Binary(Chunk.fromArray(p), finalFragment) - case WebSocketFrame.Ping(_) => ZWebSocketFrame.Ping - case WebSocketFrame.Pong(_) => ZWebSocketFrame.Pong - case WebSocketFrame.Close(code, reason) => ZWebSocketFrame.Close(code, Some(reason)) + case SttpWebSocketFrame.Text(p, finalFragment, _) => ZioWebSocketFrame.Text(p, finalFragment) + case SttpWebSocketFrame.Binary(p, finalFragment, _) => ZioWebSocketFrame.Binary(Chunk.fromArray(p), finalFragment) + case SttpWebSocketFrame.Ping(_) => ZioWebSocketFrame.Ping + case SttpWebSocketFrame.Pong(_) => ZioWebSocketFrame.Pong + case SttpWebSocketFrame.Close(code, reason) => ZioWebSocketFrame.Close(code, Some(reason)) } } diff --git a/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/package.scala b/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/package.scala index 28697c399a..4e03296183 100644 --- a/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/package.scala +++ b/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/package.scala @@ -1,11 +1,11 @@ package sttp.tapir.server import zio.Task -import zio.http.{WebSocketFrame => ZWebSocketFrame} +import zio.http.WebSocketChannel package object ziohttp { - type F2F = ZWebSocketFrame => Task[List[ZWebSocketFrame]] + type WebSocketHandler = WebSocketChannel => Task[Unit] type ZioResponseBody = - Either[F2F, ZioHttpResponseBody] + Either[WebSocketHandler, ZioHttpResponseBody] } From 834a170465dfba1d7eacdebf5814c63c03f25cc6 Mon Sep 17 00:00:00 2001 From: Yehia AboSedira Date: Sat, 16 Sep 2023 02:33:35 +0200 Subject: [PATCH 4/4] Using queues to receive and process websocket messages --- .tool-versions | 1 + doc/server/ziohttp.md | 6 + .../server/ziohttp/ZioHttpInterpreter.scala | 15 +- .../tapir/server/ziohttp/ZioWebSockets.scala | 205 +++++++++--------- .../sttp/tapir/server/ziohttp/package.scala | 7 +- .../server/ziohttp/ZioHttpServerTest.scala | 50 +++-- 6 files changed, 164 insertions(+), 120 deletions(-) create mode 100644 .tool-versions diff --git a/.tool-versions b/.tool-versions new file mode 100644 index 0000000000..40c0fdd91e --- /dev/null +++ b/.tool-versions @@ -0,0 +1 @@ +java adoptopenjdk-11.0.15+10 diff --git a/doc/server/ziohttp.md b/doc/server/ziohttp.md index 0f5d49a7fc..28df4356d0 100644 --- a/doc/server/ziohttp.md +++ b/doc/server/ziohttp.md @@ -94,6 +94,12 @@ capability. Both response bodies and request bodies can be streamed. Usage: `str The capability can be added to the classpath independently of the interpreter through the `"com.softwaremill.sttp.shared" %% "zio"` dependency. +## Web sockets + +The interpreter supports web sockets, with pipes of type `zio.stream.Stream[Throwable, REQ] => zio.stream.Stream[Throwable, RESP]`. +See [web sockets](../endpoint/websockets.md) for more details. It also supports auto-ping, auto-pong-on-ping, ignoring-pongs and handling +of fragmented frames. + ## Configuration The interpreter can be configured by providing an `ZioHttpServerOptions` value, see diff --git a/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioHttpInterpreter.scala b/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioHttpInterpreter.scala index 1f0e884146..11679ab2a0 100644 --- a/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioHttpInterpreter.scala +++ b/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioHttpInterpreter.scala @@ -46,7 +46,9 @@ trait ZioHttpInterpreter[R] { resp.body match { case None => handleHttpResponse(resp, None) case Some(Right(body)) => handleHttpResponse(resp, Some(body)) - case Some(Left(body)) => handleWebSocketResponse(body) + case Some(Left(body)) => + println(body) + handleWebSocketResponse(body) } case RequestResult.Failure(_) => @@ -85,8 +87,15 @@ trait ZioHttpInterpreter[R] { } } - private def handleWebSocketResponse(webSocketHandler: WebSocketHandler) = { - Handler.webSocket(webSocketHandler).toResponse + private def handleWebSocketResponse(webSocketHandler: WebSocketHandler): ZIO[Any, Nothing, Response] = { + Handler.webSocket { channel => + for { + channelEventsQueue <- zio.Queue.unbounded[WebSocketChannelEvent] + messageReceptionFiber <- channel.receiveAll { message => channelEventsQueue.offer(message) }.fork + webSocketStream <- webSocketHandler(stream.ZStream.fromQueue(channelEventsQueue)) + _ <- webSocketStream.mapZIO(channel.send).runDrain + } yield messageReceptionFiber.join + }.toResponse } private def handleHttpResponse( diff --git a/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioWebSockets.scala b/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioWebSockets.scala index 1b9f84a21a..48fa285c66 100644 --- a/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioWebSockets.scala +++ b/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioWebSockets.scala @@ -4,120 +4,129 @@ import sttp.capabilities.zio.ZioStreams.Pipe import sttp.tapir.model.WebSocketFrameDecodeFailure import sttp.tapir.{DecodeResult, WebSocketBodyOutput} import sttp.ws.{WebSocketFrame => SttpWebSocketFrame} +import zio.Duration.fromScala import zio.http.ChannelEvent.Read -import zio.http.{ChannelEvent, WebSocketChannel, WebSocketChannelEvent, WebSocketFrame => ZioWebSocketFrame} +import zio.http.{WebSocketChannelEvent, WebSocketFrame => ZioWebSocketFrame} import zio.stream.ZStream -import zio.{Chunk, Task, ZIO, stream} +import zio.{Chunk, Schedule, ZIO, stream} -object ZioWebSockets { - private val NormalClosureStatusCode = 1000 - private val AbnormalClosureStatusCode = 1006 - - def pipeToBody[REQ, RESP](pipe: Pipe[REQ, RESP], o: WebSocketBodyOutput[Pipe[REQ, RESP], REQ, RESP, _, ZioStreams]): WebSocketHandler = - channel => { - val reqToRespPipeline = reqToResp(pipe, o) +import scala.concurrent.duration.FiniteDuration - channel.receiveAll { - case ChannelEvent.Read(ZioWebSocketFrame.Ping) if o.autoPongOnPing => - channel.send(Read(ZioWebSocketFrame.Pong)) - case ChannelEvent.Read(message) if message.isFinal => - processWebSocketFrame(channel, reqToRespPipeline, message) - case ChannelEvent.Read(message) => // Fragmented message - for { - message <- accumulateFrames(channel, message) - response <- processWebSocketFrame(channel, reqToRespPipeline, message) - } yield response - case ChannelEvent.Unregistered => - channel.send(Read(ZioWebSocketFrame.close(NormalClosureStatusCode))) - case _ => - ZIO.unit - } - } +object ZioWebSockets { - private def reqToResp[REQ, RESP]( + def pipeToBody[REQ, RESP]( pipe: Pipe[REQ, RESP], o: WebSocketBodyOutput[Pipe[REQ, RESP], REQ, RESP, _, ZioStreams] - ): stream.Stream[Throwable, ZioWebSocketFrame] => stream.Stream[Throwable, ZioWebSocketFrame] = { - in: stream.Stream[Throwable, ZioWebSocketFrame] => - in - .map(zFrameToFrame) - .map { - case SttpWebSocketFrame.Close(_, _) if !o.decodeCloseRequests => None - case SttpWebSocketFrame.Pong(_) if o.ignorePong => None - case f: SttpWebSocketFrame => - o.requests.decode(f) match { - case failure: DecodeResult.Failure => throw new WebSocketFrameDecodeFailure(f, failure) - case DecodeResult.Value(v) => Some(v) + ): WebSocketHandler = { + { (in: stream.Stream[Throwable, WebSocketChannelEvent]) => + { + for { + pongs <- zio.Queue.bounded[SttpWebSocketFrame](1) + sttpFrames = in.map(zWebSocketChannelEventToFrame).collectSome + concatenated = optionallyConcatenate(sttpFrames, o.concatenateFragmentedFrames) + ignoredPongs = optionallyIgnorePongs(concatenated, o.ignorePong) + autoPongs = optionallyAutoPong(ignoredPongs, pongs, o.autoPongOnPing) + autoPing = optionallyAutoPing(o.autoPing) + closeStream = stream.ZStream.from[SttpWebSocketFrame](SttpWebSocketFrame.close) + intermediateStream = autoPongs + .map { + case _: SttpWebSocketFrame.Close if !o.decodeCloseRequests => None + case f => + o.requests.decode(f) match { + case failure: DecodeResult.Failure => throw new WebSocketFrameDecodeFailure(f, failure) + case DecodeResult.Value(v) => Some(v) + } } - } - .collectWhileSome - .viaFunction(pipe) - .map(o.responses.encode) - .map(frameToZFrame) + .collectWhileSome + .viaFunction(pipe) + .map(o.responses.encode) + .mergeHaltLeft(stream.ZStream.fromQueue[SttpWebSocketFrame](pongs, 1)) + .mergeHaltLeft(autoPing) ++ closeStream + sendReceiveStream = intermediateStream.map(frameToZWebSocketChannelEvent) + } yield sendReceiveStream + } + } } - private def processWebSocketFrame( - channel: WebSocketChannel, - body: stream.Stream[Throwable, ZioWebSocketFrame] => stream.Stream[Throwable, ZioWebSocketFrame], - message: ZioWebSocketFrame - ) = { - ZStream - .from(message) - .viaFunction(body) - .mapZIO(wsf => channel.send(Read(wsf))) - .runDrain + private def zWebSocketChannelEventToFrame(channelEvent: WebSocketChannelEvent): Option[SttpWebSocketFrame] = + channelEvent match { + case Read(f @ ZioWebSocketFrame.Text(text)) => Some(SttpWebSocketFrame.Text(text, f.isFinal, rsv = None)) + case Read(f @ ZioWebSocketFrame.Binary(buffer)) => Some(SttpWebSocketFrame.Binary(buffer.toArray, f.isFinal, rsv = None)) + case Read(f @ ZioWebSocketFrame.Continuation(buffer)) => Some(SttpWebSocketFrame.Binary(buffer.toArray, f.isFinal, rsv = None)) + case Read(ZioWebSocketFrame.Ping) => Some(SttpWebSocketFrame.ping) + case Read(ZioWebSocketFrame.Pong) => Some(SttpWebSocketFrame.pong) + case Read(ZioWebSocketFrame.Close(status, reason)) => Some(SttpWebSocketFrame.Close(status, reason.getOrElse(""))) + case Read(f) => Some(SttpWebSocketFrame.Binary(Array.empty[Byte], f.isFinal, rsv = None)) + case _ => None + } + + private def frameToZWebSocketChannelEvent(f: SttpWebSocketFrame): WebSocketChannelEvent = + f match { + case SttpWebSocketFrame.Text(p, finalFragment, _) => Read(ZioWebSocketFrame.Text(p, finalFragment)) + case SttpWebSocketFrame.Binary(p, finalFragment, _) => Read(ZioWebSocketFrame.Binary(Chunk.fromArray(p), finalFragment)) + case SttpWebSocketFrame.Ping(_) => Read(ZioWebSocketFrame.Ping) + case SttpWebSocketFrame.Pong(_) => Read(ZioWebSocketFrame.Pong) + case SttpWebSocketFrame.Close(code, reason) => Read(ZioWebSocketFrame.Close(code, Some(reason))) + } + + private def optionallyIgnorePongs( + sttpFrames: ZStream[Any, Throwable, SttpWebSocketFrame], + ignorePong: Boolean + ): ZStream[Any, Throwable, SttpWebSocketFrame] = { + sttpFrames + .filter { + case _: SttpWebSocketFrame.Pong if ignorePong => false + case _ => true + } } - private def accumulateFrames(channel: WebSocketChannel, webSocketFrame: ZioWebSocketFrame): Task[ZioWebSocketFrame] = { - ZIO.iterate(webSocketFrame)(!_.isFinal) { wsf => - for { - channelEvent <- channel.receive - accumulatedWebSocketFrame <- handleChannelEvent(channel, channelEvent, wsf) - } yield accumulatedWebSocketFrame + private def optionallyAutoPing( + autoPing: Option[(FiniteDuration, SttpWebSocketFrame.Ping)] + ): ZStream[Any, Nothing, SttpWebSocketFrame] = { + autoPing match { + case Some((duration, ping)) => + stream.ZStream + .from(ping) + .repeat(Schedule.fixed(fromScala(duration))) + case None => stream.ZStream.empty } } - private def handleChannelEvent( - channel: WebSocketChannel, - channelEvent: WebSocketChannelEvent, - acc: ZioWebSocketFrame - ): Task[ZioWebSocketFrame] = { - channelEvent match { - case ChannelEvent.ExceptionCaught(cause) => - channel.send(Read(ZioWebSocketFrame.close(AbnormalClosureStatusCode, Some(cause.getMessage)))) *> - channel.shutdown.map(_ => acc) - case Read(ZioWebSocketFrame.Continuation(newBuffer)) => - acc match { - case b @ ZioWebSocketFrame.Binary(bytes) => ZIO.succeed(b.copy(bytes ++ newBuffer)) - case t @ ZioWebSocketFrame.Text(text) => ZIO.succeed(t.copy(text + new String(newBuffer.toArray))) - case ZioWebSocketFrame.Close(status, reason) => - ZIO.fail(new RuntimeException(s"Received unexpected close frame: $status, $reason")) - case ZioWebSocketFrame.Continuation(buffer) => - channel.send(Read(ZioWebSocketFrame.Continuation(buffer))).map(_ => acc) - case ZioWebSocketFrame.Ping => channel.send(Read(ZioWebSocketFrame.Pong)).map(_ => acc) - case ZioWebSocketFrame.Pong => channel.send(Read(ZioWebSocketFrame.Ping)).map(_ => acc) - case _ => ZIO.succeed(acc) + private def optionallyAutoPong( + sttpFrames: ZStream[Any, Throwable, SttpWebSocketFrame], + pongs: zio.Queue[SttpWebSocketFrame], + autoPongOnPing: Boolean + ): ZStream[Any, Throwable, SttpWebSocketFrame] = { + if (autoPongOnPing) { + sttpFrames.mapZIO { + case _: SttpWebSocketFrame.Ping if autoPongOnPing => + pongs.offer(SttpWebSocketFrame.pong).as(Option.empty[SttpWebSocketFrame]) + case f => ZIO.succeed(Some(f)) + }.collectSome + } else sttpFrames + } + + private def optionallyConcatenate( + sttpFrames: ZStream[Any, Throwable, SttpWebSocketFrame], + concatenate: Boolean + ): ZStream[Any, Throwable, SttpWebSocketFrame] = { + if (concatenate) { + type Accumulator = Option[Either[Array[Byte], String]] + + sttpFrames + .mapAccum(None: Accumulator) { + case (None, f: SttpWebSocketFrame.Ping) => (None, Some(f)) + case (None, f: SttpWebSocketFrame.Pong) => (None, Some(f)) + case (None, f: SttpWebSocketFrame.Close) => (None, Some(f)) + case (None, f: SttpWebSocketFrame.Data[_]) if f.finalFragment => (None, Some(f)) + case (Some(Left(acc)), f: SttpWebSocketFrame.Binary) if f.finalFragment => (None, Some(f.copy(payload = acc ++ f.payload))) + case (Some(Left(acc)), f: SttpWebSocketFrame.Binary) if !f.finalFragment => (Some(Left(acc ++ f.payload)), None) + case (Some(Right(acc)), f: SttpWebSocketFrame.Text) if f.finalFragment => (None, Some(f.copy(payload = acc + f.payload))) + case (Some(Right(acc)), f: SttpWebSocketFrame.Text) if !f.finalFragment => (Some(Right(acc + f.payload)), None) + case (acc, f) => throw new IllegalStateException(s"Cannot accumulate web socket frames. Accumulator: $acc, frame: $f.") } - case _ => ZIO.succeed(acc) - } + .collectSome + } else sttpFrames } - private def zFrameToFrame(f: ZioWebSocketFrame): SttpWebSocketFrame = - f match { - case ZioWebSocketFrame.Text(text) => SttpWebSocketFrame.Text(text, f.isFinal, rsv = None) - case ZioWebSocketFrame.Binary(buffer) => SttpWebSocketFrame.Binary(buffer.toArray, f.isFinal, rsv = None) - case ZioWebSocketFrame.Continuation(buffer) => SttpWebSocketFrame.Binary(buffer.toArray, f.isFinal, rsv = None) - case ZioWebSocketFrame.Ping => SttpWebSocketFrame.ping - case ZioWebSocketFrame.Pong => SttpWebSocketFrame.pong - case ZioWebSocketFrame.Close(status, reason) => SttpWebSocketFrame.Close(status, reason.getOrElse("")) - case _ => SttpWebSocketFrame.Binary(Array.empty[Byte], f.isFinal, rsv = None) - } - private def frameToZFrame(f: SttpWebSocketFrame): ZioWebSocketFrame = - f match { - case SttpWebSocketFrame.Text(p, finalFragment, _) => ZioWebSocketFrame.Text(p, finalFragment) - case SttpWebSocketFrame.Binary(p, finalFragment, _) => ZioWebSocketFrame.Binary(Chunk.fromArray(p), finalFragment) - case SttpWebSocketFrame.Ping(_) => ZioWebSocketFrame.Ping - case SttpWebSocketFrame.Pong(_) => ZioWebSocketFrame.Pong - case SttpWebSocketFrame.Close(code, reason) => ZioWebSocketFrame.Close(code, Some(reason)) - } } diff --git a/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/package.scala b/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/package.scala index 4e03296183..fd18a0eeb7 100644 --- a/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/package.scala +++ b/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/package.scala @@ -1,9 +1,10 @@ package sttp.tapir.server -import zio.Task -import zio.http.WebSocketChannel +import zio.http.WebSocketChannelEvent +import zio.{ZIO, stream} package object ziohttp { - type WebSocketHandler = WebSocketChannel => Task[Unit] + type WebSocketHandler = + stream.Stream[Throwable, WebSocketChannelEvent] => ZIO[Any, Throwable, stream.Stream[Throwable, WebSocketChannelEvent]] type ZioResponseBody = Either[WebSocketHandler, ZioHttpResponseBody] diff --git a/server/zio-http-server/src/test/scala/sttp/tapir/server/ziohttp/ZioHttpServerTest.scala b/server/zio-http-server/src/test/scala/sttp/tapir/server/ziohttp/ZioHttpServerTest.scala index 07b4502598..881eb87f06 100644 --- a/server/zio-http-server/src/test/scala/sttp/tapir/server/ziohttp/ZioHttpServerTest.scala +++ b/server/zio-http-server/src/test/scala/sttp/tapir/server/ziohttp/ZioHttpServerTest.scala @@ -1,24 +1,46 @@ package sttp.tapir.server.ziohttp -import cats.effect.{IO, Resource} -import io.netty.channel.{ChannelFactory, EventLoopGroup, ServerChannel} -import org.scalatest.{Assertion, Exceptional, FutureOutcome} +import cats.effect.IO +import cats.effect.Resource +import io.netty.channel.ChannelFactory +import io.netty.channel.EventLoopGroup +import io.netty.channel.ServerChannel +import org.scalatest.Assertion +import org.scalatest.Exceptional +import org.scalatest.FutureOutcome import org.scalatest.matchers.should.Matchers._ import sttp.capabilities.zio.ZioStreams import sttp.client3._ import sttp.client3.testing.SttpBackendStub import sttp.model.MediaType import sttp.monad.MonadError -import sttp.tapir.{PublicEndpoint, _} +import sttp.tapir.PublicEndpoint +import sttp.tapir._ import sttp.tapir.server.stub.TapirStubInterpreter import sttp.tapir.server.tests._ -import sttp.tapir.tests.{Test, TestSuite} -import sttp.tapir.ztapir.{RIOMonadError, RichZEndpoint} -import zio.{Promise, Ref, Runtime, Task, UIO, Unsafe, ZEnvironment, ZIO, ZLayer} -import zio.http.{HttpAppMiddleware, Path, Request, URL} -import zio.http.netty.{ChannelFactories, ChannelType, EventLoopGroups} +import sttp.tapir.tests.Test +import sttp.tapir.tests.TestSuite +import sttp.tapir.ztapir.RIOMonadError +import sttp.tapir.ztapir.RichZEndpoint +import zio.Promise +import zio.Ref +import zio.Runtime +import zio.Task +import zio.UIO +import zio.Unsafe +import zio.ZEnvironment +import zio.ZIO +import zio.ZLayer +import zio.http.HttpAppMiddleware +import zio.http.Path +import zio.http.Request +import zio.http.URL +import zio.http.netty.ChannelFactories +import zio.http.netty.ChannelType +import zio.http.netty.EventLoopGroups import zio.interop.catz._ -import zio.stream.{ZPipeline, ZStream} +import zio.stream.ZPipeline +import zio.stream.ZStream import java.nio.charset.Charset import java.time @@ -129,13 +151,13 @@ class ZioHttpServerTest extends TestSuite { val streamingEndpoint: sttp.tapir.ztapir.ZServerEndpoint[Any, ZioStreams] = endpointModel .zServerLogic(stream => - ZIO.succeed { + ZIO.succeed({ stream .via(ZPipeline.utf8Decode) .via(ZPipeline.splitLines) .via(ZPipeline.intersperse(java.lang.System.lineSeparator())) .via(ZPipeline.utf8Encode) - } + }) ) val inputStrings = List("Hello,how,are,you", "I,am,good,thanks") val input: ZStream[Any, Nothing, Byte] = @@ -195,10 +217,6 @@ class ZioHttpServerTest extends TestSuite { ).tests() ++ new ServerStreamingTests(createServerTest, ZioStreams).tests() ++ new ZioHttpCompositionTest(createServerTest).tests() ++ - new ServerWebSocketTests(createServerTest, ZioStreams) { - override def functionToPipe[A, B](f: A => B): ZioStreams.Pipe[A, B] = in => in.map(f) - override def emptyPipe[A, B]: ZioStreams.Pipe[A, B] = _ => ZStream.empty - }.tests() ++ additionalTests() } }