-
Notifications
You must be signed in to change notification settings - Fork 422
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #3147 from yabosedira/zio-http-websocket
Initial implementation of websocket support for ZIO http.
- Loading branch information
Showing
8 changed files
with
292 additions
and
48 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
144 changes: 144 additions & 0 deletions
144
server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioWebSockets.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,144 @@ | ||
package sttp.tapir.server.ziohttp | ||
import sttp.capabilities.zio.ZioStreams | ||
import sttp.capabilities.zio.ZioStreams.Pipe | ||
import sttp.tapir.DecodeResult | ||
import sttp.tapir.WebSocketBodyOutput | ||
import sttp.tapir.model.WebSocketFrameDecodeFailure | ||
import sttp.ws.{WebSocketFrame => SttpWebSocketFrame} | ||
import zio.Chunk | ||
import zio.Duration.fromScala | ||
import zio.Schedule | ||
import zio.ZIO | ||
import zio.http.ChannelEvent.Read | ||
import zio.http.WebSocketChannelEvent | ||
import zio.http.{WebSocketFrame => ZioWebSocketFrame} | ||
import zio.stream | ||
import zio.stream.ZStream | ||
|
||
import scala.concurrent.duration.FiniteDuration | ||
|
||
object ZioWebSockets { | ||
|
||
def pipeToBody[REQ, RESP]( | ||
pipe: Pipe[REQ, RESP], | ||
o: WebSocketBodyOutput[Pipe[REQ, RESP], REQ, RESP, _, ZioStreams] | ||
): WebSocketHandler = { | ||
{ (in: stream.Stream[Throwable, WebSocketChannelEvent]) => | ||
{ | ||
for { | ||
pongs <- zio.Queue.bounded[SttpWebSocketFrame](1) | ||
sttpFrames = in.map(zWebSocketChannelEventToFrame).collectSome | ||
concatenated = optionallyConcatenate(sttpFrames, o.concatenateFragmentedFrames) | ||
ignoredPongs = optionallyIgnorePongs(concatenated, o.ignorePong) | ||
autoPongs = optionallyAutoPongOnPing(ignoredPongs, pongs, o.autoPongOnPing) | ||
autoPing = optionallyAutoPing(o.autoPing) | ||
closeStream = stream.ZStream.from[SttpWebSocketFrame](SttpWebSocketFrame.close) | ||
intermediateStream = autoPongs | ||
.map { | ||
case _: SttpWebSocketFrame.Close if !o.decodeCloseRequests => None | ||
case f => | ||
o.requests.decode(f) match { | ||
case failure: DecodeResult.Failure => throw new WebSocketFrameDecodeFailure(f, failure) | ||
case DecodeResult.Value(v) => Some(v) | ||
} | ||
} | ||
.collectWhileSome | ||
.viaFunction(pipe) | ||
.map(o.responses.encode) | ||
.mergeHaltLeft(stream.ZStream.fromQueue[SttpWebSocketFrame](pongs, 1)) | ||
.mergeHaltLeft(autoPing) ++ closeStream | ||
sendReceiveStream = intermediateStream.map(frameToZWebSocketChannelEvent) | ||
} yield sendReceiveStream | ||
} | ||
} | ||
} | ||
|
||
private def zWebSocketChannelEventToFrame(channelEvent: WebSocketChannelEvent): Option[SttpWebSocketFrame] = | ||
channelEvent match { | ||
case Read(f @ ZioWebSocketFrame.Text(text)) => Some(SttpWebSocketFrame.Text(text, f.isFinal, rsv = None)) | ||
case Read(f @ ZioWebSocketFrame.Binary(buffer)) => Some(SttpWebSocketFrame.Binary(buffer.toArray, f.isFinal, rsv = None)) | ||
case Read(f @ ZioWebSocketFrame.Continuation(buffer)) => Some(SttpWebSocketFrame.Binary(buffer.toArray, f.isFinal, rsv = None)) | ||
case Read(ZioWebSocketFrame.Ping) => Some(SttpWebSocketFrame.ping) | ||
case Read(ZioWebSocketFrame.Pong) => Some(SttpWebSocketFrame.pong) | ||
case Read(ZioWebSocketFrame.Close(status, reason)) => Some(SttpWebSocketFrame.Close(status, reason.getOrElse(""))) | ||
case Read(f) => Some(SttpWebSocketFrame.Binary(Array.empty[Byte], f.isFinal, rsv = None)) | ||
case _ => None | ||
} | ||
|
||
private def frameToZWebSocketChannelEvent(f: SttpWebSocketFrame): WebSocketChannelEvent = | ||
f match { | ||
case SttpWebSocketFrame.Text(p, finalFragment, _) => Read(ZioWebSocketFrame.Text(p, finalFragment)) | ||
case SttpWebSocketFrame.Binary(p, finalFragment, _) => Read(ZioWebSocketFrame.Binary(Chunk.fromArray(p), finalFragment)) | ||
case SttpWebSocketFrame.Ping(_) => Read(ZioWebSocketFrame.Ping) | ||
case SttpWebSocketFrame.Pong(_) => Read(ZioWebSocketFrame.Pong) | ||
case SttpWebSocketFrame.Close(code, reason) => Read(ZioWebSocketFrame.Close(code, Some(reason))) | ||
} | ||
|
||
private def optionallyIgnorePongs( | ||
sttpFrames: ZStream[Any, Throwable, SttpWebSocketFrame], | ||
ignorePong: Boolean | ||
): ZStream[Any, Throwable, SttpWebSocketFrame] = { | ||
sttpFrames | ||
.filter { | ||
case _: SttpWebSocketFrame.Pong if ignorePong => false | ||
case _ => true | ||
} | ||
} | ||
|
||
private def optionallyAutoPing( | ||
autoPing: Option[(FiniteDuration, SttpWebSocketFrame.Ping)] | ||
): ZStream[Any, Nothing, SttpWebSocketFrame] = { | ||
autoPing match { | ||
case Some((duration, ping)) => | ||
stream.ZStream | ||
.from(ping) | ||
.repeat(Schedule.fixed(fromScala(duration))) | ||
case None => stream.ZStream.empty | ||
} | ||
} | ||
|
||
private def optionallyAutoPongOnPing( | ||
sttpFrames: ZStream[Any, Throwable, SttpWebSocketFrame], | ||
pongs: zio.Queue[SttpWebSocketFrame], | ||
autoPongOnPing: Boolean | ||
): ZStream[Any, Throwable, SttpWebSocketFrame] = { | ||
if (autoPongOnPing) { | ||
sttpFrames.mapZIO { | ||
case SttpWebSocketFrame.Ping(payload) if autoPongOnPing => | ||
pongs.offer(SttpWebSocketFrame.Pong(payload)).as(Option.empty[SttpWebSocketFrame]) | ||
case f => ZIO.succeed(Some(f)) | ||
}.collectSome | ||
} else sttpFrames | ||
} | ||
|
||
private def optionallyConcatenate( | ||
sttpFrames: ZStream[Any, Throwable, SttpWebSocketFrame], | ||
concatenate: Boolean | ||
): ZStream[Any, Throwable, SttpWebSocketFrame] = { | ||
if (concatenate) { | ||
type Accumulator = Option[Either[Array[Byte], String]] | ||
|
||
sttpFrames | ||
.mapAccum(None: Accumulator) { | ||
case (None, f: SttpWebSocketFrame.Ping) => (None, Some(f)) | ||
case (None, f: SttpWebSocketFrame.Pong) => (None, Some(f)) | ||
case (None, f: SttpWebSocketFrame.Close) => (None, Some(f)) | ||
case (None, f: SttpWebSocketFrame.Data[_]) if f.finalFragment => (None, Some(f)) | ||
case (Some(Left(acc)), f: SttpWebSocketFrame.Binary) if f.finalFragment => (None, Some(f.copy(payload = acc ++ f.payload))) | ||
case (Some(Left(acc)), f: SttpWebSocketFrame.Binary) if !f.finalFragment => (Some(Left(acc ++ f.payload)), None) | ||
case (Some(Right(acc)), f: SttpWebSocketFrame.Text) if f.finalFragment => | ||
println(s"final fragment: $f") | ||
println(s"acc: $acc") | ||
(None, Some(f.copy(payload = acc + f.payload))) | ||
case (Some(Right(acc)), f: SttpWebSocketFrame.Text) if !f.finalFragment => | ||
println(s"final fragment: $f") | ||
println(s"acc: $acc") | ||
(Some(Right(acc + f.payload)), None) | ||
|
||
case (acc, f) => throw new IllegalStateException(s"Cannot accumulate web socket frames. Accumulator: $acc, frame: $f.") | ||
} | ||
.collectSome | ||
} else sttpFrames | ||
} | ||
|
||
} |
12 changes: 12 additions & 0 deletions
12
server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/package.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
package sttp.tapir.server | ||
import zio.http.WebSocketChannelEvent | ||
import zio.{ZIO, stream} | ||
|
||
package object ziohttp { | ||
type WebSocketHandler = | ||
stream.Stream[Throwable, WebSocketChannelEvent] => ZIO[Any, Throwable, stream.Stream[Throwable, WebSocketChannelEvent]] | ||
|
||
type ZioResponseBody = | ||
Either[WebSocketHandler, ZioHttpResponseBody] | ||
|
||
} |
Oops, something went wrong.