diff --git a/doc/server/ziohttp.md b/doc/server/ziohttp.md index 79cfaad320..01f9d580f6 100644 --- a/doc/server/ziohttp.md +++ b/doc/server/ziohttp.md @@ -94,6 +94,12 @@ capability. Both response bodies and request bodies can be streamed. Usage: `str The capability can be added to the classpath independently of the interpreter through the `"com.softwaremill.sttp.shared" %% "zio"` dependency. +## Web sockets + +The interpreter supports web sockets, with pipes of type `zio.stream.Stream[Throwable, REQ] => zio.stream.Stream[Throwable, RESP]`. +See [web sockets](../endpoint/websockets.md) for more details. It also supports auto-ping, auto-pong-on-ping, ignoring-pongs and handling +of fragmented frames. + ## Error handling By default, any endpoints interpreted with the `ZioHttpInterpreter` will use tapir's built-in failed effect handling, diff --git a/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioHttpBodyListener.scala b/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioHttpBodyListener.scala index 1c690a01a2..3289e5cc23 100644 --- a/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioHttpBodyListener.scala +++ b/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioHttpBodyListener.scala @@ -6,14 +6,14 @@ import zio.stream.ZStream import scala.util.{Failure, Success, Try} -private[ziohttp] class ZioHttpBodyListener[R] extends BodyListener[RIO[R, *], ZioHttpResponseBody] { - override def onComplete(body: ZioHttpResponseBody)(cb: Try[Unit] => RIO[R, Unit]): RIO[R, ZioHttpResponseBody] = +private[ziohttp] class ZioHttpBodyListener[R] extends BodyListener[RIO[R, *], ZioResponseBody] { + override def onComplete(body: ZioResponseBody)(cb: Try[Unit] => RIO[R, Unit]): RIO[R, ZioResponseBody] = ZIO .environmentWithZIO[R] .apply { r => body match { - case ZioStreamHttpResponseBody(stream, contentLength) => - ZIO.succeed( + case Right(ZioStreamHttpResponseBody(stream, contentLength)) => + ZIO.right( ZioStreamHttpResponseBody( stream.onError(cause => cb(Failure(cause.squash)).orDie.provideEnvironment(r)) ++ ZStream .fromZIO(cb(Success(()))) @@ -22,7 +22,8 @@ private[ziohttp] class ZioHttpBodyListener[R] extends BodyListener[RIO[R, *], Zi contentLength ) ) - case raw: ZioRawHttpResponseBody => cb(Success(())).provideEnvironment(r).map(_ => raw) + case raw @ Right(_: ZioRawHttpResponseBody) => cb(Success(())).provideEnvironment(r).map(_ => raw) + case ws @ Left(_) => cb(Success(())).provideEnvironment(r).map(_ => ws) } } } diff --git a/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioHttpInterpreter.scala b/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioHttpInterpreter.scala index 0c9c4b14ad..217ce33b5c 100644 --- a/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioHttpInterpreter.scala +++ b/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioHttpInterpreter.scala @@ -1,22 +1,28 @@ package sttp.tapir.server.ziohttp +import sttp.capabilities.WebSockets import sttp.capabilities.zio.ZioStreams -import sttp.model.{Method, Header => SttpHeader} +import sttp.model.Method +import sttp.model.{Header => SttpHeader} import sttp.monad.MonadError import sttp.tapir.server.interceptor.RequestResult import sttp.tapir.server.interceptor.reject.RejectInterceptor -import sttp.tapir.server.interpreter.{FilterServerEndpoints, ServerInterpreter} +import sttp.tapir.server.interpreter.FilterServerEndpoints +import sttp.tapir.server.interpreter.ServerInterpreter +import sttp.tapir.server.model.ServerResponse import sttp.tapir.ztapir._ import zio._ -import zio.http.{Header => ZioHttpHeader, Headers => ZioHttpHeaders, _} +import zio.http.{Header => ZioHttpHeader} +import zio.http.{Headers => ZioHttpHeaders} +import zio.http._ trait ZioHttpInterpreter[R] { def zioHttpServerOptions: ZioHttpServerOptions[R] = ZioHttpServerOptions.default - def toHttp[R2](se: ZServerEndpoint[R2, ZioStreams]): HttpApp[R & R2, Throwable] = + def toHttp[R2](se: ZServerEndpoint[R2, ZioStreams with WebSockets]): HttpApp[R & R2, Throwable] = toHttp(List(se)) - def toHttp[R2](ses: List[ZServerEndpoint[R2, ZioStreams]]): HttpApp[R & R2, Throwable] = { + def toHttp[R2](ses: List[ZServerEndpoint[R2, ZioStreams with WebSockets]]): HttpApp[R & R2, Throwable] = { implicit val bodyListener: ZioHttpBodyListener[R & R2] = new ZioHttpBodyListener[R & R2] implicit val monadError: MonadError[RIO[R & R2, *]] = new RIOMonadError[R & R2] val widenedSes = ses.map(_.widen[R & R2]) @@ -25,9 +31,9 @@ trait ZioHttpInterpreter[R] { val zioHttpResponseBody = new ZioHttpToResponseBody val interceptors = RejectInterceptor.disableWhenSingleEndpoint(widenedServerOptions.interceptors, widenedSes) - def handleRequest(req: Request, filteredEndpoints: List[ZServerEndpoint[R & R2, ZioStreams]]) = { + def handleRequest(req: Request, filteredEndpoints: List[ZServerEndpoint[R & R2, ZioStreams with WebSockets]]) = Handler.fromZIO { - val interpreter = new ServerInterpreter[ZioStreams, RIO[R & R2, *], ZioHttpResponseBody, ZioStreams]( + val interpreter = new ServerInterpreter[ZioStreams with WebSockets, RIO[R & R2, *], ZioResponseBody, ZioStreams]( _ => filteredEndpoints, zioHttpRequestBody, zioHttpResponseBody, @@ -41,26 +47,12 @@ trait ZioHttpInterpreter[R] { error => ZIO.fail(error), { case RequestResult.Response(resp) => - val baseHeaders = resp.headers.groupBy(_.name).flatMap(sttpToZioHttpHeader).toList - val allHeaders = resp.body.flatMap(_.contentLength) match { - case Some(contentLength) if resp.contentLength.isEmpty => - ZioHttpHeader.ContentLength(contentLength) :: baseHeaders - case _ => baseHeaders + resp.body match { + case None => handleHttpResponse(resp, None) + case Some(Right(body)) => handleHttpResponse(resp, Some(body)) + case Some(Left(body)) => handleWebSocketResponse(body) } - val statusCode = resp.code.code - - ZIO.succeed( - Response( - status = Status.fromInt(statusCode).getOrElse(Status.Custom(statusCode)), - headers = ZioHttpHeaders(allHeaders), - body = resp.body - .map { - case ZioStreamHttpResponseBody(stream, _) => Body.fromStream(stream) - case ZioRawHttpResponseBody(chunk, _) => Body.fromChunk(chunk) - } - .getOrElse(Body.empty) - ) - ) + case RequestResult.Failure(_) => ZIO.fail( new RuntimeException( @@ -73,9 +65,8 @@ trait ZioHttpInterpreter[R] { } ) } - } - val serverEndpointsFilter = FilterServerEndpoints[ZioStreams, RIO[R & R2, *]](widenedSes) + val serverEndpointsFilter = FilterServerEndpoints[ZioStreams with WebSockets, RIO[R & R2, *]](widenedSes) val singleEndpoint = widenedSes.size == 1 Http.fromOptionalHandlerZIO { request => @@ -98,19 +89,54 @@ trait ZioHttpInterpreter[R] { } } + private def handleWebSocketResponse(webSocketHandler: WebSocketHandler): ZIO[Any, Nothing, Response] = { + Handler.webSocket { channel => + for { + channelEventsQueue <- zio.Queue.unbounded[WebSocketChannelEvent] + messageReceptionFiber <- channel.receiveAll { message => channelEventsQueue.offer(message) }.fork + webSocketStream <- webSocketHandler(stream.ZStream.fromQueue(channelEventsQueue)) + _ <- webSocketStream.mapZIO(channel.send).runDrain + } yield messageReceptionFiber.join + }.toResponse + } + + private def handleHttpResponse( + resp: ServerResponse[ZioResponseBody], + body: Option[ZioHttpResponseBody] + ) = { + val baseHeaders = resp.headers.groupBy(_.name).flatMap(sttpToZioHttpHeader).toList + val allHeaders = body.flatMap(_.contentLength) match { + case Some(contentLength) if resp.contentLength.isEmpty => ZioHttpHeader.ContentLength(contentLength) :: baseHeaders + case _ => baseHeaders + } + val statusCode = resp.code.code + + ZIO.succeed( + Response( + status = Status.fromInt(statusCode).getOrElse(Status.Custom(statusCode)), + headers = ZioHttpHeaders(allHeaders), + body = body + .map { + case ZioStreamHttpResponseBody(stream, _) => Body.fromStream(stream) + case ZioRawHttpResponseBody(chunk, _) => Body.fromChunk(chunk) + } + .getOrElse(Body.empty) + ) + ) + } + private def sttpToZioHttpHeader(hl: (String, Seq[SttpHeader])): List[ZioHttpHeader] = List(ZioHttpHeader.Custom(hl._1, hl._2.map(_.value).mkString(", "))) } object ZioHttpInterpreter { - def apply[R](serverOptions: ZioHttpServerOptions[R]): ZioHttpInterpreter[R] = { + + def apply[R](serverOptions: ZioHttpServerOptions[R]): ZioHttpInterpreter[R] = new ZioHttpInterpreter[R] { override def zioHttpServerOptions: ZioHttpServerOptions[R] = serverOptions } - } - def apply(): ZioHttpInterpreter[Any] = { + def apply(): ZioHttpInterpreter[Any] = new ZioHttpInterpreter[Any] { override def zioHttpServerOptions: ZioHttpServerOptions[Any] = ZioHttpServerOptions.default[Any] } - } } diff --git a/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioHttpToResponseBody.scala b/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioHttpToResponseBody.scala index b93abe1fbd..af412f40fb 100644 --- a/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioHttpToResponseBody.scala +++ b/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioHttpToResponseBody.scala @@ -10,26 +10,31 @@ import zio.stream.ZStream import java.nio.ByteBuffer import java.nio.charset.Charset -class ZioHttpToResponseBody extends ToResponseBody[ZioHttpResponseBody, ZioStreams] { +class ZioHttpToResponseBody extends ToResponseBody[ZioResponseBody, ZioStreams] { override val streams: ZioStreams = ZioStreams - override def fromRawValue[R](v: R, headers: HasHeaders, format: CodecFormat, bodyType: RawBodyType[R]): ZioHttpResponseBody = - rawValueToEntity(bodyType, v) + override def fromRawValue[R]( + v: R, + headers: HasHeaders, + format: CodecFormat, + bodyType: RawBodyType[R] + ): ZioResponseBody = + Right(rawValueToEntity(bodyType, v)) override def fromStreamValue( v: streams.BinaryStream, headers: HasHeaders, format: CodecFormat, charset: Option[Charset] - ): ZioHttpResponseBody = ZioStreamHttpResponseBody(v, None) + ): ZioResponseBody = Right(ZioStreamHttpResponseBody(v, None)) override def fromWebSocketPipe[REQ, RESP]( pipe: streams.Pipe[REQ, RESP], o: WebSocketBodyOutput[streams.Pipe[REQ, RESP], REQ, RESP, _, ZioStreams] - ): ZioHttpResponseBody = - ZioStreamHttpResponseBody(ZStream.empty, None) // TODO + ): ZioResponseBody = + Left(ZioWebSockets.pipeToBody(pipe, o)) - private def rawValueToEntity[R](bodyType: RawBodyType[R], r: R): ZioHttpResponseBody = { + private def rawValueToEntity[R](bodyType: RawBodyType[R], r: R): ZioHttpResponseBody = bodyType match { case RawBodyType.StringBody(charset) => val bytes = r.toString.getBytes(charset) @@ -71,5 +76,4 @@ class ZioHttpToResponseBody extends ToResponseBody[ZioHttpResponseBody, ZioStrea .getOrElse(ZioStreamHttpResponseBody(ZStream.fromPath(tapirFile.file.toPath), Some(tapirFile.file.length))) case RawBodyType.MultipartBody(_, _) => throw new UnsupportedOperationException("Multipart is not supported") } - } } diff --git a/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioWebSockets.scala b/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioWebSockets.scala new file mode 100644 index 0000000000..1814d859b7 --- /dev/null +++ b/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioWebSockets.scala @@ -0,0 +1,144 @@ +package sttp.tapir.server.ziohttp +import sttp.capabilities.zio.ZioStreams +import sttp.capabilities.zio.ZioStreams.Pipe +import sttp.tapir.DecodeResult +import sttp.tapir.WebSocketBodyOutput +import sttp.tapir.model.WebSocketFrameDecodeFailure +import sttp.ws.{WebSocketFrame => SttpWebSocketFrame} +import zio.Chunk +import zio.Duration.fromScala +import zio.Schedule +import zio.ZIO +import zio.http.ChannelEvent.Read +import zio.http.WebSocketChannelEvent +import zio.http.{WebSocketFrame => ZioWebSocketFrame} +import zio.stream +import zio.stream.ZStream + +import scala.concurrent.duration.FiniteDuration + +object ZioWebSockets { + + def pipeToBody[REQ, RESP]( + pipe: Pipe[REQ, RESP], + o: WebSocketBodyOutput[Pipe[REQ, RESP], REQ, RESP, _, ZioStreams] + ): WebSocketHandler = { + { (in: stream.Stream[Throwable, WebSocketChannelEvent]) => + { + for { + pongs <- zio.Queue.bounded[SttpWebSocketFrame](1) + sttpFrames = in.map(zWebSocketChannelEventToFrame).collectSome + concatenated = optionallyConcatenate(sttpFrames, o.concatenateFragmentedFrames) + ignoredPongs = optionallyIgnorePongs(concatenated, o.ignorePong) + autoPongs = optionallyAutoPongOnPing(ignoredPongs, pongs, o.autoPongOnPing) + autoPing = optionallyAutoPing(o.autoPing) + closeStream = stream.ZStream.from[SttpWebSocketFrame](SttpWebSocketFrame.close) + intermediateStream = autoPongs + .map { + case _: SttpWebSocketFrame.Close if !o.decodeCloseRequests => None + case f => + o.requests.decode(f) match { + case failure: DecodeResult.Failure => throw new WebSocketFrameDecodeFailure(f, failure) + case DecodeResult.Value(v) => Some(v) + } + } + .collectWhileSome + .viaFunction(pipe) + .map(o.responses.encode) + .mergeHaltLeft(stream.ZStream.fromQueue[SttpWebSocketFrame](pongs, 1)) + .mergeHaltLeft(autoPing) ++ closeStream + sendReceiveStream = intermediateStream.map(frameToZWebSocketChannelEvent) + } yield sendReceiveStream + } + } + } + + private def zWebSocketChannelEventToFrame(channelEvent: WebSocketChannelEvent): Option[SttpWebSocketFrame] = + channelEvent match { + case Read(f @ ZioWebSocketFrame.Text(text)) => Some(SttpWebSocketFrame.Text(text, f.isFinal, rsv = None)) + case Read(f @ ZioWebSocketFrame.Binary(buffer)) => Some(SttpWebSocketFrame.Binary(buffer.toArray, f.isFinal, rsv = None)) + case Read(f @ ZioWebSocketFrame.Continuation(buffer)) => Some(SttpWebSocketFrame.Binary(buffer.toArray, f.isFinal, rsv = None)) + case Read(ZioWebSocketFrame.Ping) => Some(SttpWebSocketFrame.ping) + case Read(ZioWebSocketFrame.Pong) => Some(SttpWebSocketFrame.pong) + case Read(ZioWebSocketFrame.Close(status, reason)) => Some(SttpWebSocketFrame.Close(status, reason.getOrElse(""))) + case Read(f) => Some(SttpWebSocketFrame.Binary(Array.empty[Byte], f.isFinal, rsv = None)) + case _ => None + } + + private def frameToZWebSocketChannelEvent(f: SttpWebSocketFrame): WebSocketChannelEvent = + f match { + case SttpWebSocketFrame.Text(p, finalFragment, _) => Read(ZioWebSocketFrame.Text(p, finalFragment)) + case SttpWebSocketFrame.Binary(p, finalFragment, _) => Read(ZioWebSocketFrame.Binary(Chunk.fromArray(p), finalFragment)) + case SttpWebSocketFrame.Ping(_) => Read(ZioWebSocketFrame.Ping) + case SttpWebSocketFrame.Pong(_) => Read(ZioWebSocketFrame.Pong) + case SttpWebSocketFrame.Close(code, reason) => Read(ZioWebSocketFrame.Close(code, Some(reason))) + } + + private def optionallyIgnorePongs( + sttpFrames: ZStream[Any, Throwable, SttpWebSocketFrame], + ignorePong: Boolean + ): ZStream[Any, Throwable, SttpWebSocketFrame] = { + sttpFrames + .filter { + case _: SttpWebSocketFrame.Pong if ignorePong => false + case _ => true + } + } + + private def optionallyAutoPing( + autoPing: Option[(FiniteDuration, SttpWebSocketFrame.Ping)] + ): ZStream[Any, Nothing, SttpWebSocketFrame] = { + autoPing match { + case Some((duration, ping)) => + stream.ZStream + .from(ping) + .repeat(Schedule.fixed(fromScala(duration))) + case None => stream.ZStream.empty + } + } + + private def optionallyAutoPongOnPing( + sttpFrames: ZStream[Any, Throwable, SttpWebSocketFrame], + pongs: zio.Queue[SttpWebSocketFrame], + autoPongOnPing: Boolean + ): ZStream[Any, Throwable, SttpWebSocketFrame] = { + if (autoPongOnPing) { + sttpFrames.mapZIO { + case SttpWebSocketFrame.Ping(payload) if autoPongOnPing => + pongs.offer(SttpWebSocketFrame.Pong(payload)).as(Option.empty[SttpWebSocketFrame]) + case f => ZIO.succeed(Some(f)) + }.collectSome + } else sttpFrames + } + + private def optionallyConcatenate( + sttpFrames: ZStream[Any, Throwable, SttpWebSocketFrame], + concatenate: Boolean + ): ZStream[Any, Throwable, SttpWebSocketFrame] = { + if (concatenate) { + type Accumulator = Option[Either[Array[Byte], String]] + + sttpFrames + .mapAccum(None: Accumulator) { + case (None, f: SttpWebSocketFrame.Ping) => (None, Some(f)) + case (None, f: SttpWebSocketFrame.Pong) => (None, Some(f)) + case (None, f: SttpWebSocketFrame.Close) => (None, Some(f)) + case (None, f: SttpWebSocketFrame.Data[_]) if f.finalFragment => (None, Some(f)) + case (Some(Left(acc)), f: SttpWebSocketFrame.Binary) if f.finalFragment => (None, Some(f.copy(payload = acc ++ f.payload))) + case (Some(Left(acc)), f: SttpWebSocketFrame.Binary) if !f.finalFragment => (Some(Left(acc ++ f.payload)), None) + case (Some(Right(acc)), f: SttpWebSocketFrame.Text) if f.finalFragment => + println(s"final fragment: $f") + println(s"acc: $acc") + (None, Some(f.copy(payload = acc + f.payload))) + case (Some(Right(acc)), f: SttpWebSocketFrame.Text) if !f.finalFragment => + println(s"final fragment: $f") + println(s"acc: $acc") + (Some(Right(acc + f.payload)), None) + + case (acc, f) => throw new IllegalStateException(s"Cannot accumulate web socket frames. Accumulator: $acc, frame: $f.") + } + .collectSome + } else sttpFrames + } + +} diff --git a/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/package.scala b/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/package.scala new file mode 100644 index 0000000000..fd18a0eeb7 --- /dev/null +++ b/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/package.scala @@ -0,0 +1,12 @@ +package sttp.tapir.server +import zio.http.WebSocketChannelEvent +import zio.{ZIO, stream} + +package object ziohttp { + type WebSocketHandler = + stream.Stream[Throwable, WebSocketChannelEvent] => ZIO[Any, Throwable, stream.Stream[Throwable, WebSocketChannelEvent]] + + type ZioResponseBody = + Either[WebSocketHandler, ZioHttpResponseBody] + +} diff --git a/server/zio-http-server/src/test/scala/sttp/tapir/server/ziohttp/ZioHttpServerTest.scala b/server/zio-http-server/src/test/scala/sttp/tapir/server/ziohttp/ZioHttpServerTest.scala index 881eb87f06..03a577d974 100644 --- a/server/zio-http-server/src/test/scala/sttp/tapir/server/ziohttp/ZioHttpServerTest.scala +++ b/server/zio-http-server/src/test/scala/sttp/tapir/server/ziohttp/ZioHttpServerTest.scala @@ -2,6 +2,7 @@ package sttp.tapir.server.ziohttp import cats.effect.IO import cats.effect.Resource +import cats.implicits.toTraverseOps import io.netty.channel.ChannelFactory import io.netty.channel.EventLoopGroup import io.netty.channel.ServerChannel @@ -22,6 +23,8 @@ import sttp.tapir.tests.Test import sttp.tapir.tests.TestSuite import sttp.tapir.ztapir.RIOMonadError import sttp.tapir.ztapir.RichZEndpoint +import sttp.ws.WebSocket +import sttp.ws.WebSocketFrame import zio.Promise import zio.Ref import zio.Runtime @@ -39,13 +42,14 @@ import zio.http.netty.ChannelFactories import zio.http.netty.ChannelType import zio.http.netty.EventLoopGroups import zio.interop.catz._ +import zio.stream import zio.stream.ZPipeline import zio.stream.ZStream import java.nio.charset.Charset import java.time import scala.concurrent.Future - +import scala.concurrent.duration.DurationInt class ZioHttpServerTest extends TestSuite { // zio-http tests often fail with "Cause: java.io.IOException: parsing HTTP/1.1 status line, receiving [DEFAULT], parser state [STATUS_LINE]" @@ -190,6 +194,45 @@ class ZioHttpServerTest extends TestSuite { } Unsafe.unsafe(implicit u => r.unsafe.runToFuture(test)) + }, + createServerTest.testServer( + endpoint + .out( + webSocketBody[String, CodecFormat.TextPlain, String, CodecFormat.TextPlain] + .apply(ZioStreams) + .autoPing(Some((1.second, WebSocketFrame.ping))) + ), + "auto pings" + )((_: Unit) => ZIO.right((in: stream.Stream[Throwable, String]) => in.map(v => s"echo $v"))) { (backend, baseUri) => + basicRequest + .response(asWebSocket { (ws: WebSocket[IO]) => + List(ws.receive().timeout(60.seconds), ws.receive().timeout(60.seconds)).sequence + }) + .get(baseUri.scheme("ws")) + .send(backend) + .map(_.body should matchPattern { case Right(List(WebSocketFrame.Ping(_), WebSocketFrame.Ping(_))) => }) + }, + createServerTest.testServer( + endpoint + .out( + webSocketBody[String, CodecFormat.TextPlain, String, CodecFormat.TextPlain] + .apply(ZioStreams) + .autoPing(None) + ), + "ping-pong echo" + )((_: Unit) => ZIO.right((in: stream.Stream[Throwable, String]) => in.map(v => s"echo $v"))) { (backend, baseUri) => + basicRequest + .response(asWebSocket { (ws: WebSocket[IO]) => + for { + _ <- ws.send(WebSocketFrame.ping) + m2 <- ws.receive() + } yield List(m2) + }) + .get(baseUri.scheme("ws")) + .send(backend) + .map { response => + response.body should matchPattern { case Right(List(_: WebSocketFrame.Pong)) => } + } } ) @@ -217,6 +260,10 @@ class ZioHttpServerTest extends TestSuite { ).tests() ++ new ServerStreamingTests(createServerTest, ZioStreams).tests() ++ new ZioHttpCompositionTest(createServerTest).tests() ++ + new ServerWebSocketTests(createServerTest, ZioStreams) { + override def functionToPipe[A, B](f: A => B): ZioStreams.Pipe[A, B] = in => in.map(f) + override def emptyPipe[A, B]: ZioStreams.Pipe[A, B] = _ => ZStream.empty + }.tests() ++ additionalTests() } } diff --git a/server/zio-http-server/src/test/scala/sttp/tapir/server/ziohttp/ZioHttpTestServerInterpreter.scala b/server/zio-http-server/src/test/scala/sttp/tapir/server/ziohttp/ZioHttpTestServerInterpreter.scala index bb2e5a2bbd..3a53e49ae0 100644 --- a/server/zio-http-server/src/test/scala/sttp/tapir/server/ziohttp/ZioHttpTestServerInterpreter.scala +++ b/server/zio-http-server/src/test/scala/sttp/tapir/server/ziohttp/ZioHttpTestServerInterpreter.scala @@ -3,6 +3,7 @@ package sttp.tapir.server.ziohttp import cats.data.NonEmptyList import cats.effect.{IO, Resource} import io.netty.channel.{ChannelFactory, EventLoopGroup, ServerChannel} +import sttp.capabilities.WebSockets import sttp.capabilities.zio.ZioStreams import sttp.tapir.server.ServerEndpoint import sttp.tapir.server.tests.TestServerInterpreter @@ -16,9 +17,12 @@ class ZioHttpTestServerInterpreter( channelFactory: ZLayer[Any, Nothing, ChannelFactory[ServerChannel]] )(implicit trace: Trace -) extends TestServerInterpreter[Task, ZioStreams, ZioHttpServerOptions[Any], Http[Any, Throwable, Request, Response]] { +) extends TestServerInterpreter[Task, ZioStreams with WebSockets, ZioHttpServerOptions[Any], Http[Any, Throwable, Request, Response]] { - override def route(es: List[ServerEndpoint[ZioStreams, Task]], interceptors: Interceptors): Http[Any, Throwable, Request, Response] = { + override def route( + es: List[ServerEndpoint[ZioStreams with WebSockets, Task]], + interceptors: Interceptors + ): Http[Any, Throwable, Request, Response] = { val serverOptions: ZioHttpServerOptions[Any] = interceptors(ZioHttpServerOptions.customiseInterceptors).options ZioHttpInterpreter(serverOptions).toHttp(es) }