Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

websockets support in zio server #1789

Closed
wants to merge 8 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions doc/server/ziohttp.md
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,11 @@ capability. Both response bodies and request bodies can be streamed. Usage: `str
The capability can be added to the classpath independently of the interpreter through the
`"com.softwaremill.sttp.shared" %% "zio"` dependency.

## Web sockets

The interpreter supports web sockets, with pipes of type `zio.stream.Stream[Throwable, REQ] => zio.stream.Stream[Throwable, RESP]`.
See [web sockets](../endpoint/websockets.md) for more details.

## Configuration

The interpreter can be configured by providing an `ZioHttpServerOptions` value, see
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ class AkkaHttpServerTest extends TestSuite with EitherValues {

new AllServerTests(createServerTest, interpreter, backend).tests() ++
new ServerStreamingTests(createServerTest, AkkaStreams).tests() ++
new ServerWebSocketTests(createServerTest, AkkaStreams) {
new ServerWebSocketTests(createServerTest, AkkaStreams, concatenateFragmentedFrames = false) {
override def functionToPipe[A, B](f: A => B): streams.Pipe[A, B] = Flow.fromFunction(f)
}.tests() ++
additionalTests()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class PlayServerTest extends TestSuite {
new AllServerTests(createServerTest, interpreter, backend, basic = false, multipart = false, reject = false).tests() ++
new ServerStreamingTests(createServerTest, AkkaStreams).tests() ++
new PlayServerWithContextTest(backend).tests() ++
new ServerWebSocketTests(createServerTest, AkkaStreams) {
new ServerWebSocketTests(createServerTest, AkkaStreams, concatenateFragmentedFrames = false) {
override def functionToPipe[A, B](f: A => B): streams.Pipe[A, B] = Flow.fromFunction(f)
}.tests()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ import sttp.ws.{WebSocket, WebSocketFrame}

abstract class ServerWebSocketTests[F[_], S <: Streams[S], ROUTE](
createServerTest: CreateServerTest[F, S with WebSockets, ROUTE],
val streams: S
val streams: S,
val concatenateFragmentedFrames: Boolean = true
)(implicit
m: MonadError[F]
) {
Expand All @@ -29,106 +30,143 @@ abstract class ServerWebSocketTests[F[_], S <: Streams[S], ROUTE](
private def stringWs = webSocketBody[String, CodecFormat.TextPlain, String, CodecFormat.TextPlain].apply(streams)
private def stringEcho = functionToPipe((s: String) => s"echo: $s")

def tests(): List[Test] = List(
testServer(
endpoint.out(stringWs),
"string client-terminated echo"
)((_: Unit) => pureResult(stringEcho.asRight[Unit])) { (backend, baseUri) =>
basicRequest
.response(asWebSocket { (ws: WebSocket[IO]) =>
for {
_ <- ws.sendText("test1")
_ <- ws.sendText("test2")
m1 <- ws.receiveText()
m2 <- ws.receiveText()
} yield List(m1, m2)
})
.get(baseUri.scheme("ws"))
.send(backend)
.map(_.body shouldBe Right(List("echo: test1", "echo: test2")))
}, {
def tests(): List[Test] = {
val basicTests = List(
testServer(
endpoint.out(stringWs),
"string client-terminated echo"
)((_: Unit) => pureResult(stringEcho.asRight[Unit])) { (backend, baseUri) =>
basicRequest
.response(asWebSocket { (ws: WebSocket[IO]) =>
for {
_ <- ws.sendText("test1")
m1 <- ws.receiveText()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why the changed ordering of operations? didn't the original test work?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, the order of received messages was not deterministic for zio-server

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hm well that's bad. Either a bug in zio-http, or somewhere else. I'd revert those tests to their original ordering and look for the root cause. That's quite normal WS usage, that you send multiple messages in sequence and then get a response

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it's probably zio-http, I'll keep that in mind and see what will came up in the issue I reported - that may be related I think

_ <- ws.sendText("test2")
m2 <- ws.receiveText()
} yield List(m1, m2)
})
.get(baseUri.scheme("ws"))
.send(backend)
.map(_.body shouldBe Right(List("echo: test1", "echo: test2")))
}, {

val reqCounter = newRequestCounter[F]
val resCounter = newResponseCounter[F]
val metrics = new MetricsRequestInterceptor[F](List(reqCounter, resCounter), Seq.empty)
val reqCounter = newRequestCounter[F]
val resCounter = newResponseCounter[F]
val metrics = new MetricsRequestInterceptor[F](List(reqCounter, resCounter), Seq.empty)

testServer(endpoint.out(stringWs).name("metrics"), metricsInterceptor = metrics.some)((_: Unit) =>
pureResult(stringEcho.asRight[Unit])
testServer(endpoint.out(stringWs).name("metrics"), metricsInterceptor = metrics.some)((_: Unit) =>
pureResult(stringEcho.asRight[Unit])
) { (backend, baseUri) =>
basicRequest
.response(asWebSocket { (ws: WebSocket[IO]) =>
for {
_ <- ws.sendText("test1")
m <- ws.receiveText()
} yield List(m)
})
.get(baseUri.scheme("ws"))
.send(backend)
.map { r =>
r.body shouldBe Right(List("echo: test1"))
reqCounter.metric.value.get() shouldBe 1
resCounter.metric.value.get() shouldBe 1
}
}
},
testServer(endpoint.out(webSocketBody[Fruit, CodecFormat.Json, Fruit, CodecFormat.Json](streams)), "json client-terminated echo")(
(_: Unit) => pureResult(functionToPipe((f: Fruit) => Fruit(s"echo: ${f.f}")).asRight[Unit])
) { (backend, baseUri) =>
basicRequest
.response(asWebSocket { (ws: WebSocket[IO]) =>
for {
_ <- ws.sendText("""{"f":"apple"}""")
m1 <- ws.receiveText()
_ <- ws.sendText("""{"f":"orange"}""")
m2 <- ws.receiveText()
} yield List(m1, m2)
})
.get(baseUri.scheme("ws"))
.send(backend)
.map(_.body shouldBe Right(List("""{"f":"echo: apple"}""", """{"f":"echo: orange"}""")))
},
testServer(
endpoint.out(webSocketBody[String, CodecFormat.TextPlain, Option[String], CodecFormat.TextPlain](streams)),
"string server-terminated echo"
)((_: Unit) =>
pureResult(functionToPipe[String, Option[String]] {
case "end" => None
case msg => Some(s"echo: $msg")
}.asRight[Unit])
) { (backend, baseUri) =>
basicRequest
.response(asWebSocket { (ws: WebSocket[IO]) =>
for {
_ <- ws.sendText("test1")
m <- ws.receiveText()
} yield List(m)
m1 <- ws.eitherClose(ws.receiveText())
_ <- ws.sendText("test2")
m2 <- ws.eitherClose(ws.receiveText())
_ <- ws.sendText("end")
m3 <- ws.eitherClose(ws.receiveText())
} yield List(m1, m2, m3)
})
.get(baseUri.scheme("ws"))
.send(backend)
.map { r =>
r.body shouldBe Right(List("echo: test1"))
reqCounter.metric.value.get() shouldBe 1
resCounter.metric.value.get() shouldBe 1
}
}
},
testServer(endpoint.out(webSocketBody[Fruit, CodecFormat.Json, Fruit, CodecFormat.Json](streams)), "json client-terminated echo")(
(_: Unit) => pureResult(functionToPipe((f: Fruit) => Fruit(s"echo: ${f.f}")).asRight[Unit])
) { (backend, baseUri) =>
basicRequest
.response(asWebSocket { (ws: WebSocket[IO]) =>
for {
_ <- ws.sendText("""{"f":"apple"}""")
_ <- ws.sendText("""{"f":"orange"}""")
m1 <- ws.receiveText()
m2 <- ws.receiveText()
} yield List(m1, m2)
})
.get(baseUri.scheme("ws"))
.send(backend)
.map(_.body shouldBe Right(List("""{"f":"echo: apple"}""", """{"f":"echo: orange"}""")))
},
testServer(
endpoint.out(webSocketBody[String, CodecFormat.TextPlain, Option[String], CodecFormat.TextPlain](streams)),
"string server-terminated echo"
)((_: Unit) =>
pureResult(functionToPipe[String, Option[String]] {
case "end" => None
case msg => Some(s"echo: $msg")
}.asRight[Unit])
) { (backend, baseUri) =>
basicRequest
.response(asWebSocket { (ws: WebSocket[IO]) =>
for {
_ <- ws.sendText("test1")
_ <- ws.sendText("test2")
_ <- ws.sendText("end")
m1 <- ws.eitherClose(ws.receiveText())
m2 <- ws.eitherClose(ws.receiveText())
m3 <- ws.eitherClose(ws.receiveText())
} yield List(m1, m2, m3)
})
.get(baseUri.scheme("ws"))
.send(backend)
.map(
_.body.map(_.map(_.left.map(_.statusCode))) shouldBe Right(
List(Right("echo: test1"), Right("echo: test2"), Left(WebSocketFrame.close.statusCode))
.map(
_.body.map(_.map(_.left.map(_.statusCode))) shouldBe Right(
List(Right("echo: test1"), Right("echo: test2"), Left(WebSocketFrame.close.statusCode))
)
)
)
},
testServer(
endpoint
.in(isWebSocket)
.errorOut(stringBody)
.out(stringWs),
"non web-socket request"
)(isWS => if (isWS) pureResult(stringEcho.asRight) else pureResult("Not a WS!".asLeft)) { (backend, baseUri) =>
basicRequest
.response(asString)
.get(baseUri.scheme("http"))
.send(backend)
.map(_.body shouldBe Left("Not a WS!"))
}
)
},
testServer(
endpoint
.in(isWebSocket)
.errorOut(stringBody)
.out(stringWs),
"non web-socket request"
)(isWS => if (isWS) pureResult(stringEcho.asRight) else pureResult("Not a WS!".asLeft)) { (backend, baseUri) =>
basicRequest
.response(asString)
.get(baseUri.scheme("http"))
.send(backend)
.map(_.body shouldBe Left("Not a WS!"))
},
testServer(
endpoint.out(stringWs),
"pong on ping"
)((_: Unit) => pureResult(stringEcho.asRight[Unit])) { (backend, baseUri) =>
basicRequest
.response(asWebSocket { (ws: WebSocket[IO]) =>
for {
_ <- ws.send(WebSocketFrame.ping)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hm I'm not sure if all backends give acces to ping-pongs, so if this fails e.g. on akka-http we might need to parametrize the tests

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it passed

pong <- ws.receive()
} yield pong
})
.get(baseUri.scheme("ws"))
.send(backend)
.map(_.body.map(_.isInstanceOf[WebSocketFrame.Pong]) shouldBe Right(true))
}
)

val concatenateFramesTest = List(
testServer(
endpoint.out(stringWs.concatenateFragmentedFrames(true)),
"concatenate fragmented frames"
)((_: Unit) => pureResult(stringEcho.asRight[Unit])) { (backend, baseUri) =>
basicRequest
.response(asWebSocket { (ws: WebSocket[IO]) =>
for {
_ <- ws.send(WebSocketFrame.Text("hello-", finalFragment = false, rsv = None))
_ <- ws.send(WebSocketFrame.Text("from-", finalFragment = false, rsv = None))
_ <- ws.send(WebSocketFrame.Text("server", finalFragment = true, rsv = None))
text <- ws.receiveText()
} yield text
})
.get(baseUri.scheme("ws"))
.send(backend)
.map(_.body shouldBe Right("echo: hello-from-server"))
}
)

// TODO: tests for ping/pong (control frames handling)
basicTests ++ (if (concatenateFragmentedFrames) concatenateFramesTest else Nil)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,12 @@ import zio.stream.ZStream

import scala.util.{Failure, Success, Try}

class ZioHttpBodyListener[R] extends BodyListener[RIO[R, *], ZStream[Any, Throwable, Byte]] {
override def onComplete(body: ZStream[Any, Throwable, Byte])(cb: Try[Unit] => RIO[R, Unit]): RIO[R, ZStream[Any, Throwable, Byte]] =
RIO
.access[R]
.apply(r => body.onError(cause => cb(Failure(cause.squash)).orDie.provide(r)) ++ ZStream.fromEffect(cb(Success(()))).provide(r).drain)
class ZioHttpBodyListener[R] extends BodyListener[RIO[R, *], ZioResponseBody] {
override def onComplete(body: ZioResponseBody)(cb: Try[Unit] => RIO[R, Unit]): RIO[R, ZioResponseBody] =
body match {
case ws @ Left(_) => cb(Success(())).map(_ => ws)
case Right(stream) => RIO
.access[R]
.apply(r => Right(stream.onError(cause => cb(Failure(cause.squash)).orDie.provide(r)) ++ ZStream.fromEffect(cb(Success(()))).provide(r).drain))
}
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package sttp.tapir.server.ziohttp

import io.netty.handler.codec.http.HttpResponseStatus
import sttp.capabilities.WebSockets
import sttp.capabilities.zio.ZioStreams
import sttp.model.{Header => SttpHeader}
import sttp.monad.MonadError
Expand All @@ -9,20 +10,20 @@ import sttp.tapir.server.interpreter.ServerInterpreter
import sttp.tapir.server.ziohttp.ZioHttpInterpreter.zioMonadError
import sttp.tapir.ztapir._
import zhttp.http.{Http, HttpData, Request, Response, Status, Header => ZioHttpHeader, Headers => ZioHttpHeaders}
import zhttp.socket._
import zio._
import zio.stream.Stream

trait ZioHttpInterpreter[R] {

def zioHttpServerOptions: ZioHttpServerOptions[R] = ZioHttpServerOptions.default

def toHttp(se: ZServerEndpoint[R, ZioStreams]): Http[R, Throwable, Request, Response[R, Throwable]] =
def toHttp(se: ZServerEndpoint[R, ZioStreams with WebSockets]): Http[R, Throwable, Request, Response[R, Throwable]] =
toHttp(List(se))

def toHttp(ses: List[ZServerEndpoint[R, ZioStreams]]): Http[R, Throwable, Request, Response[R, Throwable]] = {
def toHttp(ses: List[ZServerEndpoint[R, ZioStreams with WebSockets]]): Http[R, Throwable, Request, Response[R, Throwable]] = {
implicit val bodyListener: ZioHttpBodyListener[R] = new ZioHttpBodyListener[R]
implicit val monadError: MonadError[RIO[R, *]] = zioMonadError[R]
val interpreter = new ServerInterpreter[ZioStreams, RIO[R, *], Stream[Throwable, Byte], ZioStreams](
val interpreter = new ServerInterpreter[ZioStreams with WebSockets, RIO[R, *], ZioResponseBody, ZioStreams](
ses,
new ZioHttpToResponseBody,
zioHttpServerOptions.interceptors,
Expand All @@ -36,12 +37,14 @@ trait ZioHttpInterpreter[R] {
.apply(new ZioHttpServerRequest(req), new ZioHttpRequestBody(req, new ZioHttpServerRequest(req), zioHttpServerOptions))
.map {
case RequestResult.Response(resp) =>
val status = Status.fromHttpResponseStatus(HttpResponseStatus.valueOf(resp.code.code))
val headers = ZioHttpHeaders(resp.headers.groupBy(_.name).map(sttpToZioHttpHeader).toList)
Http.succeed(
Response(
status = Status.fromHttpResponseStatus(HttpResponseStatus.valueOf(resp.code.code)),
headers = ZioHttpHeaders(resp.headers.groupBy(_.name).map(sttpToZioHttpHeader).toList),
data = resp.body.map(stream => HttpData.fromStream(stream)).getOrElse(HttpData.empty)
)
resp.body match {
case None => Response(status, headers, HttpData.empty)
case Some(Left(socket)) => asResponse(SocketApp.message(socket))
case Some(Right(stream)) => Response(status, headers, HttpData.fromStream(stream))
}
)
case RequestResult.Failure(_) => Http.empty
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,42 +6,42 @@ import sttp.tapir.server.interpreter.ToResponseBody
import sttp.tapir.{CodecFormat, FileRange, RawBodyType, WebSocketBodyOutput}
import zio.Chunk
import zio.blocking.Blocking
import zio.stream.{Stream, ZStream}
import zio.stream.Stream

import java.nio.charset.Charset

class ZioHttpToResponseBody extends ToResponseBody[ZStream[Any, Throwable, Byte], ZioStreams] {
class ZioHttpToResponseBody extends ToResponseBody[ZioResponseBody, ZioStreams] {
override val streams: ZioStreams = ZioStreams

override def fromRawValue[R](v: R, headers: HasHeaders, format: CodecFormat, bodyType: RawBodyType[R]): ZStream[Any, Throwable, Byte] =
override def fromRawValue[R](v: R, headers: HasHeaders, format: CodecFormat, bodyType: RawBodyType[R]): ZioResponseBody =
rawValueToEntity(bodyType, v)

override def fromStreamValue(
v: streams.BinaryStream,
headers: HasHeaders,
format: CodecFormat,
charset: Option[Charset]
): ZStream[Any, Throwable, Byte] = v
): ZioResponseBody = Right(v)

override def fromWebSocketPipe[REQ, RESP](
pipe: streams.Pipe[REQ, RESP],
o: WebSocketBodyOutput[streams.Pipe[REQ, RESP], REQ, RESP, _, ZioStreams]
): ZStream[Any, Throwable, Byte] =
Stream.empty // TODO
): ZioResponseBody =
Left(ZioWebSockets.pipeToBody(pipe, o))

private def rawValueToEntity[CF <: CodecFormat, R](bodyType: RawBodyType[R], r: R): ZStream[Any, Throwable, Byte] = {
private def rawValueToEntity[CF <: CodecFormat, R](bodyType: RawBodyType[R], r: R): ZioResponseBody = Right {
bodyType match {
case RawBodyType.StringBody(charset) => ZStream.fromIterable(r.toString.getBytes(charset))
case RawBodyType.StringBody(charset) => Stream.fromIterable(r.toString.getBytes(charset))
case RawBodyType.ByteArrayBody => Stream.fromChunk(Chunk.fromArray(r))
case RawBodyType.ByteBufferBody => Stream.fromChunk(Chunk.fromByteBuffer(r))
case RawBodyType.InputStreamBody => ZStream.fromInputStream(r).provideLayer(Blocking.live)
case RawBodyType.InputStreamBody => Stream.fromInputStream(r).provideLayer(Blocking.live)
case RawBodyType.FileBody =>
val tapirFile = r.asInstanceOf[FileRange]
tapirFile.range
.flatMap(r =>
r.startAndEnd.map(s => ZStream.fromFile(tapirFile.file.toPath).drop(s._1).take(r.contentLength).provideLayer(Blocking.live))
r.startAndEnd.map(s => Stream.fromFile(tapirFile.file.toPath).drop(s._1).take(r.contentLength).provideLayer(Blocking.live))
)
.getOrElse(ZStream.fromFile(tapirFile.file.toPath).provideLayer(Blocking.live))
.getOrElse(Stream.fromFile(tapirFile.file.toPath).provideLayer(Blocking.live))
case RawBodyType.MultipartBody(_, _) => Stream.empty
}
}
Expand Down
Loading