Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix WebSocket frame concatenation for Netty #3801

Merged
merged 2 commits into from
May 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,8 @@ private[http4s] object Http4sWebSockets {
case (None, f: WebSocketFrame.Pong) => (None, Some(f))
case (None, f: WebSocketFrame.Close) => (None, Some(f))
case (None, f: WebSocketFrame.Data[_]) if f.finalFragment => (None, Some(f))
case (None, f: WebSocketFrame.Text) => (Some(Right(f.payload)), None)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

indeed this seems rather basic :D

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

but if http4s concatenates frames by itself, shouldn't we remove this? it's dead code anyway?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, let me double check though, maybe there's some kind of switch on the Blaze/Ember level to control this.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, turns out that:

  1. Http4s can be configured to not concatenate fragmented frames and not ignore pings/pongs on the level of WebSocketBuilder2:
      .withHttpWebSocketApp(wsb => Router("/" -> wsRoutes(wsb.withDefragment(false).withFilterPingPongs(false)).orNotFound)
  1. This has no effect if used with Blaze, which always automatically defragments and filters pings pongs
  2. It works with Ember though.

I'll remove these stages from Http4sWebSockets and add some documentation.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or we can keep our own defragmentation and force wsb.withDefragment(false) to make this Tapir-native switch work on Ember by forcing passing through our concatenation 🤔

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah ... if it's optional, let's keep the code, and let the user of http4s decide. So I guess the code is good as-is

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It has several bugs actually, as I discovered when I switched from Blaze to Ember in tests. Fixes are on the way in a separate PR then ;)

case (None, f: WebSocketFrame.Binary) => (Some(Left(f.payload)), None)
case (Some(Left(acc)), f: WebSocketFrame.Binary) if f.finalFragment => (None, Some(f.copy(payload = acc ++ f.payload)))
case (Some(Left(acc)), f: WebSocketFrame.Binary) if !f.finalFragment => (Some(Left(acc ++ f.payload)), None)
case (Some(Right(acc)), f: WebSocketFrame.Text) if f.finalFragment => (None, Some(f.copy(payload = acc + f.payload)))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,19 +100,7 @@ class WebSocketPipeProcessor[F[_]: Async, REQ, RESP](

private def optionallyConcatenateFrames(s: Stream[F, WebSocketFrame], doConcatenate: Boolean): Stream[F, WebSocketFrame] =
if (doConcatenate) {
type Accumulator = Option[Either[Array[Byte], String]]

s.mapAccumulate(None: Accumulator) {
case (None, f: WebSocketFrame.Ping) => (None, Some(f))
case (None, f: WebSocketFrame.Pong) => (None, Some(f))
case (None, f: WebSocketFrame.Close) => (None, Some(f))
case (None, f: WebSocketFrame.Data[_]) if f.finalFragment => (None, Some(f))
case (Some(Left(acc)), f: WebSocketFrame.Binary) if f.finalFragment => (None, Some(f.copy(payload = acc ++ f.payload)))
case (Some(Left(acc)), f: WebSocketFrame.Binary) if !f.finalFragment => (Some(Left(acc ++ f.payload)), None)
case (Some(Right(acc)), f: WebSocketFrame.Text) if f.finalFragment => (None, Some(f.copy(payload = acc + f.payload)))
case (Some(Right(acc)), f: WebSocketFrame.Text) if !f.finalFragment => (Some(Right(acc + f.payload)), None)
case (acc, f) => throw new IllegalStateException(s"Cannot accumulate web socket frames. Accumulator: $acc, frame: $f.")
}.collect { case (_, Some(f)) => f }
s.mapAccumulate(None: Accumulator)(accumulateFrameState).collect { case (_, Some(f)) => f }
} else s
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,4 +35,23 @@ object WebSocketFrameConverters {
case WebSocketFrame.Binary(payload, finalFragment, rsvOpt) =>
new BinaryWebSocketFrame(finalFragment, rsvOpt.getOrElse(0), Unpooled.wrappedBuffer(payload))
}

type Accumulator = Option[Either[Array[Byte], String]]
val accumulateFrameState: (Accumulator, WebSocketFrame) => (Accumulator, Option[WebSocketFrame]) = {
case (None, f: WebSocketFrame.Ping) => (None, Some(f))
case (None, f: WebSocketFrame.Pong) => (None, Some(f))
case (None, f: WebSocketFrame.Close) => (None, Some(f))
case (None, f: WebSocketFrame.Data[_]) if f.finalFragment => (None, Some(f))
case (None, f: WebSocketFrame.Text) => (Some(Right(f.payload)), None)
case (None, f: WebSocketFrame.Binary) => (Some(Left(f.payload)), None)
case (Some(Left(acc)), f: WebSocketFrame.Binary) if f.finalFragment => (None, Some(f.copy(payload = acc ++ f.payload)))
case (Some(Left(acc)), f: WebSocketFrame.Binary) if !f.finalFragment => (Some(Left(acc ++ f.payload)), None)
case (Some(Right(acc)), f: WebSocketFrame.Binary) if f.finalFragment =>
// Netty's ContinuationFrame is translated to Binary, so we need to handle a Binary frame received after accumulating Text
(None, Some(WebSocketFrame.Text(payload = acc + new String(f.payload), finalFragment = true, rsv = f.rsv)))
case (Some(Right(acc)), f: WebSocketFrame.Binary) if !f.finalFragment => (Some(Right(acc + new String(f.payload))), None)
case (Some(Right(acc)), f: WebSocketFrame.Text) if f.finalFragment => (None, Some(f.copy(payload = acc + f.payload)))
case (Some(Right(acc)), f: WebSocketFrame.Text) if !f.finalFragment => (Some(Right(acc + f.payload)), None)
case (acc, f) => throw new IllegalStateException(s"Cannot accumulate web socket frames. Accumulator: $acc, frame: $f.")
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -66,16 +66,5 @@ private[sync] object OxSourceWebSocketProcessor:

private def optionallyConcatenateFrames(s: Source[WebSocketFrame], doConcatenate: Boolean)(using Ox): Source[WebSocketFrame] =
if doConcatenate then
type Accumulator = Option[Either[Array[Byte], String]]
s.mapStateful(() => None: Accumulator) {
case (None, f: WebSocketFrame.Ping) => (None, Some(f))
case (None, f: WebSocketFrame.Pong) => (None, Some(f))
case (None, f: WebSocketFrame.Close) => (None, Some(f))
case (None, f: WebSocketFrame.Data[_]) if f.finalFragment => (None, Some(f))
case (Some(Left(acc)), f: WebSocketFrame.Binary) if f.finalFragment => (None, Some(f.copy(payload = acc ++ f.payload)))
case (Some(Left(acc)), f: WebSocketFrame.Binary) if !f.finalFragment => (Some(Left(acc ++ f.payload)), None)
case (Some(Right(acc)), f: WebSocketFrame.Text) if f.finalFragment => (None, Some(f.copy(payload = acc + f.payload)))
case (Some(Right(acc)), f: WebSocketFrame.Text) if !f.finalFragment => (Some(Right(acc + f.payload)), None)
case (acc, f) => throw new IllegalStateException(s"Cannot accumulate web socket frames. Accumulator: $acc, frame: $f.")
}.collectAsView { case Some(f: WebSocketFrame) => f }
s.mapStateful(() => None: Accumulator)(accumulateFrameState).collectAsView { case Some(f: WebSocketFrame) => f }
else s
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ abstract class ServerWebSocketTests[F[_], S <: Streams[S], OPTIONS, ROUTE](
failingPipe: Boolean,
handlePong: Boolean,
// Disabled for eaxmple for vert.x, which sometimes drops connection without returning Close
expectCloseResponse: Boolean = true
expectCloseResponse: Boolean = true,
frameConcatenation: Boolean = true
)(implicit
m: MonadError[F]
) extends EitherValues {
Expand Down Expand Up @@ -244,7 +245,7 @@ abstract class ServerWebSocketTests[F[_], S <: Streams[S], OPTIONS, ROUTE](
response2.body shouldBe Right("echo: testOk")
}
}
) ++ autoPingTests ++ failingPipeTests ++ handlePongTests
) ++ autoPingTests ++ failingPipeTests ++ handlePongTests ++ frameConcatenationTests

val autoPingTests =
if (autoPing)
Expand Down Expand Up @@ -314,6 +315,53 @@ abstract class ServerWebSocketTests[F[_], S <: Streams[S], OPTIONS, ROUTE](
)
else List.empty

val frameConcatenationTests = if (frameConcatenation) List(
testServer(
endpoint.out(
webSocketBody[String, CodecFormat.TextPlain, String, CodecFormat.TextPlain](streams)
.autoPing(None)
.autoPongOnPing(false)
.concatenateFragmentedFrames(true)
),
"concatenate fragmented text frames"
)((_: Unit) => pureResult(stringEcho.asRight[Unit])) { (backend, baseUri) =>
basicRequest
.response(asWebSocket { (ws: WebSocket[IO]) =>
for {
_ <- ws.send(WebSocketFrame.Text("f1", finalFragment = false, None))
_ <- ws.sendText("f2")
r <- ws.receiveText()
_ <- ws.close()
} yield r
})
.get(baseUri.scheme("ws"))
.send(backend)
.map { _.body shouldBe (Right("echo: f1f2")) }
},
testServer(
endpoint.out(
webSocketBody[Array[Byte], CodecFormat.OctetStream, String, CodecFormat.TextPlain](streams)
.autoPing(None)
.autoPongOnPing(false)
.concatenateFragmentedFrames(true)
),
"concatenate fragmented binary frames"
)((_: Unit) => pureResult(functionToPipe((bs: Array[Byte]) => s"echo: ${new String(bs)}").asRight[Unit])) { (backend, baseUri) =>
basicRequest
.response(asWebSocket { (ws: WebSocket[IO]) =>
for {
_ <- ws.send(WebSocketFrame.Binary("frame1-bytes;".getBytes(), finalFragment = false, None))
_ <- ws.sendBinary("frame2-bytes".getBytes())
r <- ws.receiveText()
_ <- ws.close()
} yield r
})
.get(baseUri.scheme("ws"))
.send(backend)
.map { _.body shouldBe (Right("echo: frame1-bytes;frame2-bytes")) }
}
) else Nil

val handlePongTests =
if (handlePong)
List(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ class CatsVertxServerTest extends TestSuite {
autoPing = false,
failingPipe = false,
handlePong = true,
expectCloseResponse = false
expectCloseResponse = false,
frameConcatenation = false
) {
override def functionToPipe[A, B](f: A => B): streams.Pipe[A, B] = in => in.map(f)
override def emptyPipe[A, B]: streams.Pipe[A, B] = _ => Stream.empty
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ object websocket {
(None, Some(f))
case (None, f: WebSocketFrame.Data[_]) if f.finalFragment =>
(None, Some(f))
case (None, f: WebSocketFrame.Text) =>
(Some(Right(f.payload)), None)
case (None, f: WebSocketFrame.Binary) =>
(Some(Left(f.payload)), None)
case (Some(Left(acc)), f: WebSocketFrame.Binary) if f.finalFragment =>
(None, Some(f.copy(payload = acc ++ f.payload)))
case (Some(Left(acc)), f: WebSocketFrame.Binary) if !f.finalFragment =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@ class VertxServerTest extends TestSuite {
autoPing = false,
failingPipe = false,
handlePong = false,
expectCloseResponse = false
expectCloseResponse = false,
frameConcatenation = false
) {
override def functionToPipe[A, B](f: A => B): VertxStreams.Pipe[A, B] = in => new ReadStreamMapping(in, f)
override def emptyPipe[A, B]: VertxStreams.Pipe[A, B] = _ => new EmptyReadStream()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ class ZioVertxServerTest extends TestSuite with OptionValues {
autoPing = true,
failingPipe = false,
handlePong = false,
expectCloseResponse = false
expectCloseResponse = false,
frameConcatenation = false
) {
override def functionToPipe[A, B](f: A => B): streams.Pipe[A, B] = in => in.map(f)
override def emptyPipe[A, B]: streams.Pipe[A, B] = _ => ZStream.empty
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,15 +124,13 @@ object ZioWebSockets {
case (None, f: SttpWebSocketFrame.Pong) => (None, Some(f))
case (None, f: SttpWebSocketFrame.Close) => (None, Some(f))
case (None, f: SttpWebSocketFrame.Data[_]) if f.finalFragment => (None, Some(f))
case (None, f: SttpWebSocketFrame.Text) => (Some(Right(f.payload)), None)
case (None, f: SttpWebSocketFrame.Binary) => (Some(Left(f.payload)), None)
case (Some(Left(acc)), f: SttpWebSocketFrame.Binary) if f.finalFragment => (None, Some(f.copy(payload = acc ++ f.payload)))
case (Some(Left(acc)), f: SttpWebSocketFrame.Binary) if !f.finalFragment => (Some(Left(acc ++ f.payload)), None)
case (Some(Right(acc)), f: SttpWebSocketFrame.Text) if f.finalFragment =>
println(s"final fragment: $f")
println(s"acc: $acc")
(None, Some(f.copy(payload = acc + f.payload)))
case (Some(Right(acc)), f: SttpWebSocketFrame.Text) if !f.finalFragment =>
println(s"final fragment: $f")
println(s"acc: $acc")
(Some(Right(acc + f.payload)), None)

case (acc, f) => throw new IllegalStateException(s"Cannot accumulate web socket frames. Accumulator: $acc, frame: $f.")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,8 @@ class ZioHttpServerTest extends TestSuite {
ZioStreams,
autoPing = true,
failingPipe = false,
handlePong = false
handlePong = false,
frameConcatenation = false
) {
override def functionToPipe[A, B](f: A => B): ZioStreams.Pipe[A, B] = in => in.map(f)
override def emptyPipe[A, B]: ZioStreams.Pipe[A, B] = _ => ZStream.empty
Expand Down
Loading