Skip to content

Commit

Permalink
Ox WebSockets (#2187)
Browse files Browse the repository at this point in the history
  • Loading branch information
kciesielski authored May 27, 2024
1 parent 092a8d0 commit 4b5a428
Show file tree
Hide file tree
Showing 6 changed files with 479 additions and 20 deletions.
46 changes: 35 additions & 11 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,22 @@ jobs:
matrix:
scala-version: [ "2.12", "2.13", "3" ]
target-platform: [ "JVM", "JS", "Native" ]
java: [ "11", "21" ]
exclude:
- java: "21"
include: # Restricted to build only specific Loom-based modules
- scala-version: "3"
target-platform: "JVM"
java: "21"
steps:
- name: Checkout
uses: actions/checkout@v2
- name: Set up JDK 11
uses: actions/setup-java@v1
- name: Set up JDK
uses: actions/setup-java@v4
with:
java-version: 11
distribution: 'temurin'
cache: 'sbt'
java-version: ${{ matrix.java }}
- name: Cache sbt
uses: actions/cache@v2
with:
Expand All @@ -40,14 +49,18 @@ jobs:
sudo apt-get update
sudo apt-get install libidn2-dev libcurl3-dev
echo "STTP_NATIVE=1" >> $GITHUB_ENV
- name: Enable Loom-specific modules
if: matrix.java == '21'
run: echo "ONLY_LOOM=1" >> $GITHUB_ENV
- name: Compile
run: sbt -v "compileScoped ${{ matrix.scala-version }} ${{ matrix.target-platform }}"
- name: Compile documentation
if: matrix.target-platform == 'JVM'
if: matrix.target-platform == 'JVM' && matrix.java == '11'
run: sbt -v compileDocs
- name: Test
run: sbt -v "testScoped ${{ matrix.scala-version }} ${{ matrix.target-platform }}"
- name: Prepare release notes
if: matrix.java == '11'
uses: release-drafter/release-drafter@v5
with:
config-name: release-drafter.yml
Expand Down Expand Up @@ -100,16 +113,19 @@ jobs:
name: Publish release
needs: [ci]
if: github.event_name != 'pull_request' && (startsWith(github.ref, 'refs/tags/v'))
runs-on: ubuntu-20.04
env:
STTP_NATIVE: 1
runs-on: ubuntu-22.04
strategy:
matrix:
java: [ "11", "21" ]
steps:
- name: Checkout
uses: actions/checkout@v2
- name: Set up JDK 11
uses: actions/setup-java@v1
- name: Set up JDK
uses: actions/setup-java@v4
with:
java-version: 11
distribution: 'temurin'
java-version: ${{ matrix.java }}
cache: 'sbt'
- name: Cache sbt
uses: actions/cache@v2
with:
Expand All @@ -122,6 +138,12 @@ jobs:
run: |
sudo apt-get update
sudo apt-get install libidn2-dev libcurl3-dev
- name: Enable Native-specific modules
if: matrix.java == '11'
run: echo "STTP_NATIVE=1" >> $GITHUB_ENV
- name: Enable Loom-specific modules
if: matrix.java == '21'
run: echo "ONLY_LOOM=1" >> $GITHUB_ENV
- name: Compile
run: sbt compile
- name: Publish artifacts
Expand All @@ -132,12 +154,14 @@ jobs:
SONATYPE_USERNAME: ${{ secrets.SONATYPE_USERNAME }}
SONATYPE_PASSWORD: ${{ secrets.SONATYPE_PASSWORD }}
- name: Extract version from commit message
if: matrix.java == '11'
run: |
version=${GITHUB_REF/refs\/tags\/v/}
echo "VERSION=$version" >> $GITHUB_ENV
env:
COMMIT_MSG: ${{ github.event.head_commit.message }}
- name: Publish release notes
if: matrix.java == '11'
uses: release-drafter/release-drafter@v5
with:
config-name: release-drafter.yml
Expand Down Expand Up @@ -191,4 +215,4 @@ jobs:
uses: "pascalgn/[email protected]"
env:
GITHUB_TOKEN: "${{ secrets.GITHUB_TOKEN }}"
MERGE_METHOD: "squash"
MERGE_METHOD: "squash"
65 changes: 57 additions & 8 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ val zio2Version = "2.1.1"
val zio1InteropRsVersion = "1.3.12"
val zio2InteropRsVersion = "2.0.2"

val oxVersion = "0.2.0"
val sttpModelVersion = "1.7.10"
val sttpSharedVersion = "1.3.18"

Expand All @@ -168,24 +169,44 @@ val openTelemetryVersion = "1.38.0"

val compileAndTest = "compile->compile;test->test"

lazy val projectsWithOptionalNative: Seq[ProjectReference] = {
val base = core.projectRefs ++ jsonCommon.projectRefs ++ upickle.projectRefs
if (sys.env.isDefinedAt("STTP_NATIVE")) {
println("[info] STTP_NATIVE defined, including sttp-native in the aggregate projects")
base
lazy val loomProjects: Seq[String] = Seq(ox, examples3).flatMap(_.projectRefs).flatMap(projectId)

def projectId(projectRef: ProjectReference): Option[String] =
projectRef match {
case ProjectRef(_, id) => Some(id)
case LocalProject(id) => Some(id)
case _ => None
}

lazy val allAggregates: Seq[ProjectReference] = {
val filteredByNative = if (sys.env.isDefinedAt("STTP_NATIVE")) {
println("[info] STTP_NATIVE defined, including native in the aggregate projects")
rawAllAggregates
} else {
println("[info] STTP_NATIVE *not* defined, *not* including sttp-native in the aggregate projects")
base.filterNot(_.toString.contains("Native"))
println("[info] STTP_NATIVE *not* defined, *not* including native in the aggregate projects")
rawAllAggregates.filterNot(_.toString.contains("Native"))
}
if (sys.env.isDefinedAt("ONLY_LOOM")) {
println("[info] ONLY_LOOM defined, including only loom-based projects")
filteredByNative.filter(p => projectId(p).forall(loomProjects.contains))
} else if (sys.env.isDefinedAt("ALSO_LOOM")) {
println("[info] ALSO_LOOM defined, including also loom-based projects")
filteredByNative
} else {
println("[info] ONLY_LOOM *not* defined, *not* including loom-based-projects")
filteredByNative.filterNot(p => projectId(p).forall(loomProjects.contains))
}
}

lazy val allAggregates = projectsWithOptionalNative ++

lazy val rawAllAggregates =
testCompilation.projectRefs ++
catsCe2.projectRefs ++
cats.projectRefs ++
fs2Ce2.projectRefs ++
fs2.projectRefs ++
monix.projectRefs ++
ox.projectRefs ++
scalaz.projectRefs ++
zio1.projectRefs ++
zio.projectRefs ++
Expand Down Expand Up @@ -231,6 +252,7 @@ lazy val allAggregates = projectsWithOptionalNative ++
slf4jBackend.projectRefs ++
examplesCe2.projectRefs ++
examples.projectRefs ++
examples3.projectRefs ++
docs.projectRefs ++
testServer.projectRefs

Expand Down Expand Up @@ -439,6 +461,18 @@ lazy val monix = (projectMatrix in file("effects/monix"))
settings = commonJsSettings ++ commonJsBackendSettings ++ browserChromeTestSettings ++ testServerSettings
)

lazy val ox = (projectMatrix in file("effects/ox"))
.settings(commonJvmSettings)
.settings(
name := "ox",
libraryDependencies ++= Seq(
"com.softwaremill.ox" %% "core" % oxVersion
)
)
.settings(testServerSettings)
.jvmPlatform(scalaVersions = scala3)
.dependsOn(core % compileAndTest)

lazy val zio1 = (projectMatrix in file("effects/zio1"))
.settings(
name := "zio1",
Expand Down Expand Up @@ -1038,6 +1072,21 @@ lazy val examples = (projectMatrix in file("examples"))
slf4jBackend
)

lazy val examples3 = (projectMatrix in file("examples3"))
.settings(commonJvmSettings)
.settings(
name := "examples3",
publish / skip := true,
libraryDependencies ++= Seq(
logback
)
)
.jvmPlatform(scalaVersions = scala3)
.dependsOn(
core,
ox
)

//TODO this should be invoked by compilation process, see #https://github.com/scalameta/mdoc/issues/355
val compileDocs: TaskKey[Unit] = taskKey[Unit]("Compiles docs module throwing away its output")
compileDocs := {
Expand Down
48 changes: 47 additions & 1 deletion docs/websockets.md
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,52 @@ effect type class name
================ ==========================================
```

## WebSockets as Ox Source and Sink

[Ox](https://ox.softwaremill.com) is a Scala 3 toolkit that allows you to handle concurrency and resiliency in direct-style, leveraging Java 21 virtual threads.
If you're using Ox with `sttp`, you can use the `DefaultSyncBackend` from `sttp-core` for HTTP communication. An additional `ox` module allows handling WebSockets
as Ox `Source` and `Sink`:

```
// sbt dependency
"com.softwaremill.sttp.client4" %% "ox" % "@VERSION@",
```

```scala
import ox.*
import ox.channels.{Sink, Source}
import sttp.client4.*
import sttp.client4.impl.ox.ws.* // import to access asSourceAnkSink
import sttp.client4.ws.SyncWebSocket
import sttp.client4.ws.sync.*
import sttp.ws.WebSocketFrame

def useWebSocket(ws: SyncWebSocket): Unit =
supervised {
val (wsSource, wsSink) = asSourceAndSink(ws) // (Source[WebSocketFrame], Sink[WebSocketFrame])
// ...
}

val backend = DefaultSyncBackend()
basicRequest
.get(uri"wss://ws.postman-echo.com/raw")
.response(asWebSocket(useWebSocket))
.send(backend)
```

See the [full example here](https://github.com/softwaremill/sttp/blob/master/examples/src/main/scala/sttp/client4/examples3/WebSocketOx.scala).

Make sure that the `Source` is contiunually read. This will guarantee that server-side Close signal is received and handled.
If you don't want to process frames from the server, you can at least handle it with a `fork { source.drain() }`.

You don't need to manually call `ws.close()` when using this approach, this will be handled automatically underneath,
according to following rules:
- If the request `Sink` is closed due to an upstream error, a Close frame is sent, and the `Source` with incoming responses gets completed as `Done`.
- If the request `Sink` completes as `Done`, a `Close` frame is sent, and the response `Sink` keeps receiving responses until the server closes communication.
- If the response `Source` is closed by a Close frome from the server or due to an error, the request Sink is closed as `Done`, which will still send all outstanding buffered frames, and then finish.

Read more about Ox, structured concurrency, Sources and Sinks on the [project website](https://ox.softwaremill.com).

## Compression

For those who plan to use a lot of websocket traffic, you could consider websocket compression. See the information on
Expand Down Expand Up @@ -122,4 +168,4 @@ Web socket settings can be adjusted by providing a custom `AsyncHttpClientConfig
Some available settings:

* maximum web socket frame size. Default: 10240, can be changed using `.setWebSocketMaxFrameSize`.
* compression. Default: false, can be changed using: `.setEnablewebSocketCompression`.
* compression. Default: false, can be changed using: `.setEnablewebSocketCompression`.
119 changes: 119 additions & 0 deletions effects/ox/src/main/scala/sttp/client4/impl/ox/ws/OxWebSockets.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
package sttp.client4.impl.ox.ws

import ox.*
import ox.channels.*
import sttp.client4.ws.SyncWebSocket
import sttp.ws.WebSocketFrame

import scala.util.control.NonFatal

/** Converts a [[SyncWebSocket]] into a pair of `Source` of server responses and a `Sink` for client requests. The
* `Source` starts receiving frames immediately, its internal buffer size can be adjusted with an implicit
* [[ox.channels.StageCapacity]]. Make sure that the `Source` is contiunually read. This will guarantee that
* server-side Close signal is received and handled. If you don't want to process frames from the server, you can at
* least handle it with a `fork { source.drain() }`.
*
* You don't need to manually call `ws.close()` when using this approach, this will be handled automatically
* underneath, according to following rules:
* - If the request `Sink` is closed due to an upstream error, a Close frame is sent, and the `Source` with incoming
* responses gets completed as `Done`.
* - If the request `Sink` completes as `Done`, a `Close` frame is sent, and the response `Sink` keeps receiving
* responses until the server closes communication.
* - If the response `Source` is closed by a Close frome from the server or due to an error, the request Sink is
* closed as `Done`, which will still send all outstanding buffered frames, and then finish.
*
* @param ws
* a `SyncWebSocket` where the underlying `Sink` will send requests, and where the `Source` will pull responses from.
* @param concatenateFragmented
* whether fragmented frames from the server should be concatenated into a single frame (true by default).
*/
def asSourceAndSink(ws: SyncWebSocket, concatenateFragmented: Boolean = true)(using
Ox,
StageCapacity
): (Source[WebSocketFrame], Sink[WebSocketFrame]) =
val requestsChannel = StageCapacity.newChannel[WebSocketFrame]
val responsesChannel = StageCapacity.newChannel[WebSocketFrame]
fork {
try
repeatWhile {
ws.receive() match
case frame: WebSocketFrame.Data[_] =>
responsesChannel.sendOrClosed(frame) match
case _: ChannelClosed => false
case _ => true
case WebSocketFrame.Close(status, msg) if status > 1001 =>
responsesChannel.errorOrClosed(new WebSocketClosedWithError(status, msg)).discard
false
case _: WebSocketFrame.Close =>
responsesChannel.doneOrClosed().discard
false
case ping: WebSocketFrame.Ping =>
requestsChannel.sendOrClosed(WebSocketFrame.Pong(ping.payload)).discard
// Keep receiving even if pong couldn't be send due to closed request channel. We want to process
// whatever responses there are still coming from the server until it signals the end with a Close frome.
true
case _: WebSocketFrame.Pong =>
// ignore pongs
true
}
catch
case NonFatal(err) =>
responsesChannel.errorOrClosed(err).discard
finally requestsChannel.doneOrClosed().discard
}.discard

fork {
try
repeatWhile {
requestsChannel.receiveOrClosed() match
case closeFrame: WebSocketFrame.Close =>
ws.send(closeFrame)
false
case frame: WebSocketFrame =>
ws.send(frame)
true
case ChannelClosed.Done =>
ws.close()
false
case ChannelClosed.Error(err) =>
// There's no proper "client error" status. Statuses 4000+ are available for custom cases
ws.send(WebSocketFrame.Close(4000, "Client error"))
responsesChannel.doneOrClosed().discard
false
}
catch
case NonFatal(err) =>
// If responses are closed, server finished the communication and we can ignore that send() failed
if (!responsesChannel.isClosedForReceive) requestsChannel.errorOrClosed(err).discard
}.discard

(optionallyConcatenateFrames(responsesChannel, concatenateFragmented), requestsChannel)

final case class WebSocketClosedWithError(statusCode: Int, msg: String)
extends Exception(s"WebSocket closed with status $statusCode: $msg")

private def optionallyConcatenateFrames(s: Source[WebSocketFrame], doConcatenate: Boolean)(using
Ox
): Source[WebSocketFrame] =
if doConcatenate then
type Accumulator = Option[Either[Array[Byte], String]]
s.mapStateful(() => None: Accumulator) {
case (None, f: WebSocketFrame.Ping) => (None, Some(f))
case (None, f: WebSocketFrame.Pong) => (None, Some(f))
case (None, f: WebSocketFrame.Close) => (None, Some(f))
case (None, f: WebSocketFrame.Data[_]) if f.finalFragment => (None, Some(f))
case (None, f: WebSocketFrame.Text) => (Some(Right(f.payload)), None)
case (None, f: WebSocketFrame.Binary) => (Some(Left(f.payload)), None)
case (Some(Left(acc)), f: WebSocketFrame.Binary) if f.finalFragment =>
(None, Some(f.copy(payload = acc ++ f.payload)))
case (Some(Left(acc)), f: WebSocketFrame.Binary) if !f.finalFragment => (Some(Left(acc ++ f.payload)), None)
case (Some(Right(acc)), f: WebSocketFrame.Text) if f.finalFragment =>
(None, Some(f.copy(payload = acc + f.payload)))
case (Some(Right(acc)), f: WebSocketFrame.Text) if !f.finalFragment => (Some(Right(acc + f.payload)), None)
case (acc, f) =>
throw new IllegalStateException(
s"Unexpected WebSocket frame received during concatenation. Frame received: ${f.getClass
.getSimpleName()}, accumulator type: ${acc.map(_.getClass.getSimpleName)}"
)
}.collectAsView { case Some(f: WebSocketFrame) => f }
else s
Loading

0 comments on commit 4b5a428

Please sign in to comment.