From 07f12810cdf671c3055e5fd083a8a6b4ae2e97be Mon Sep 17 00:00:00 2001 From: Sven Date: Mon, 21 Aug 2023 16:07:54 +0200 Subject: [PATCH 1/3] Always return an id upon errors --- .../main/scala/caliban/interop/tapir/ws/Protocol.scala | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/interop/tapir/src/main/scala/caliban/interop/tapir/ws/Protocol.scala b/interop/tapir/src/main/scala/caliban/interop/tapir/ws/Protocol.scala index 2394a7c86c..b5b20dba9c 100644 --- a/interop/tapir/src/main/scala/caliban/interop/tapir/ws/Protocol.scala +++ b/interop/tapir/src/main/scala/caliban/interop/tapir/ws/Protocol.scala @@ -9,6 +9,7 @@ import caliban.interop.tapir.WebSocketHooks import zio.stm.TMap import zio.stream.{ UStream, ZStream } import zio.{ Duration, Promise, Queue, Ref, Schedule, UIO, URIO, ZIO } +import java.util.UUID sealed trait Protocol { def name: String @@ -49,15 +50,20 @@ object Protocol { override def complete(id: String): GraphQLWSOutput = GraphQLWSOutput(Ops.Complete, Some(id), None) - override def error[E](id: Option[String], e: E): GraphQLWSOutput = + override def error[E](id: Option[String], e: E): GraphQLWSOutput = { + val outputId = id match { + case None => Some(UUID.randomUUID().toString) + case Some(_) => id + } GraphQLWSOutput( Ops.Error, - id, + outputId, Some(ResponseValue.ListValue(List(e match { case e: CalibanError => e.toResponseValue case e => StringValue(e.toString) }))) ) + } } override def make[R, E]( From 483d4a1946cbbd9c4b3b6a538716b0d77bc48a43 Mon Sep 17 00:00:00 2001 From: Sven Date: Tue, 22 Aug 2023 13:35:25 +0200 Subject: [PATCH 2/3] According to the server implementation in graphql-transport-ws always close the socket with 4403 on error --- .../src/main/scala/caliban/interop/tapir/ws/Protocol.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/interop/tapir/src/main/scala/caliban/interop/tapir/ws/Protocol.scala b/interop/tapir/src/main/scala/caliban/interop/tapir/ws/Protocol.scala index b5b20dba9c..b19b6a82d7 100644 --- a/interop/tapir/src/main/scala/caliban/interop/tapir/ws/Protocol.scala +++ b/interop/tapir/src/main/scala/caliban/interop/tapir/ws/Protocol.scala @@ -82,7 +82,7 @@ object Protocol { case GraphQLWSInput(Ops.ConnectionInit, id, payload) => val before = ZIO.whenCase((webSocketHooks.beforeInit, payload)) { case (Some(beforeInit), Some(payload)) => - beforeInit(payload).catchAll(e => output.offer(Right(handler.error(id, e)))) + beforeInit(payload).catchAll(_ => output.offer(Left(GraphQLWSClose(4403, "Forbidden")))) } val ackPayload = webSocketHooks.onAck.fold[URIO[R, Option[ResponseValue]]](ZIO.none)(_.option) val response = From 8af118f2a43f7562763e6e1324741668361efa2c Mon Sep 17 00:00:00 2001 From: Sven Date: Fri, 25 Aug 2023 09:34:25 +0200 Subject: [PATCH 3/3] Moved over to using zio.Random for id generation --- .../caliban/interop/tapir/ws/Protocol.scala | 39 ++++++++++++------- 1 file changed, 24 insertions(+), 15 deletions(-) diff --git a/interop/tapir/src/main/scala/caliban/interop/tapir/ws/Protocol.scala b/interop/tapir/src/main/scala/caliban/interop/tapir/ws/Protocol.scala index b19b6a82d7..48b00d55cd 100644 --- a/interop/tapir/src/main/scala/caliban/interop/tapir/ws/Protocol.scala +++ b/interop/tapir/src/main/scala/caliban/interop/tapir/ws/Protocol.scala @@ -8,8 +8,7 @@ import caliban.interop.tapir.TapirAdapter.CalibanPipe import caliban.interop.tapir.WebSocketHooks import zio.stm.TMap import zio.stream.{ UStream, ZStream } -import zio.{ Duration, Promise, Queue, Ref, Schedule, UIO, URIO, ZIO } -import java.util.UUID +import zio.{ Duration, Promise, Queue, Random, Ref, Schedule, UIO, URIO, ZIO } sealed trait Protocol { def name: String @@ -50,20 +49,15 @@ object Protocol { override def complete(id: String): GraphQLWSOutput = GraphQLWSOutput(Ops.Complete, Some(id), None) - override def error[E](id: Option[String], e: E): GraphQLWSOutput = { - val outputId = id match { - case None => Some(UUID.randomUUID().toString) - case Some(_) => id - } + override def error[E](id: Option[String], e: E): GraphQLWSOutput = GraphQLWSOutput( Ops.Error, - outputId, + id, Some(ResponseValue.ListValue(List(e match { case e: CalibanError => e.toResponseValue case e => StringValue(e.toString) }))) ) - } } override def make[R, E]( @@ -92,7 +86,7 @@ object Protocol { afterInit .catchAllCause(cause => ZIO.foreachDiscard(cause.failureOption)(e => - output.offer(Right(handler.error(id, e))) + generateId(id).flatMap(uuid => output.offer(Right(handler.error(uuid, e)))) ) *> output.offer(Left(GraphQLWSClose(4401, "Unauthorized"))) ) .fork @@ -101,14 +95,20 @@ object Protocol { before *> response *> ka *> after case GraphQLWSInput(Ops.Pong, id, payload) => ZIO.whenCase(webSocketHooks.onPong -> payload) { case (Some(onPong), Some(payload)) => - onPong(payload).catchAll(e => output.offer(Right(handler.error(id, e)))) + onPong(payload).catchAll(e => + generateId(id).flatMap(uuid => output.offer(Right(handler.error(uuid, e)))) + ) } case GraphQLWSInput(Ops.Ping, id, payload) => def sendPong(p: Option[ResponseValue]) = output.offer(Right(GraphQLWSOutput(Ops.Pong, id, p))) webSocketHooks.onPing match { case Some(onPing) => - onPing(payload).flatMap(sendPong).catchAll(e => output.offer(Right(handler.error(id, e)))) + onPing(payload) + .flatMap(sendPong) + .catchAll(e => + generateId(id).flatMap(uuid => output.offer(Right(handler.error(uuid, e)))) + ) case _ => sendPong(None) } case GraphQLWSInput(Ops.Subscribe, Some(id), payload) => @@ -137,7 +137,8 @@ object Protocol { .unit ) - case None => output.offer(Right(connectionError)) + case None => + generateId(None).flatMap(uuid => output.offer(Right(connectionError(uuid)))) } ZIO.ifZIO(ack.get)(continue, output.offer(Left(GraphQLWSClose(4401, "Unauthorized")))) @@ -146,7 +147,9 @@ object Protocol { case GraphQLWSInput(unsupported, _, _) => output.offer(Left(GraphQLWSClose(4400, s"Unsupported operation: $unsupported"))) }.runDrain.interruptible - .catchAll(_ => output.offer(Right(connectionError))) + .catchAll(_ => + generateId(None).flatMap(uuid => output.offer(Right(connectionError(Some(uuid.toString()))))) + ) .ensuring(subscriptions.untrackAll) .provideEnvironment(env) .forkScoped @@ -154,10 +157,16 @@ object Protocol { } } yield pipe - private val connectionError: GraphQLWSOutput = GraphQLWSOutput(Ops.Error, None, None) + private def connectionError(id: Option[String]): GraphQLWSOutput = GraphQLWSOutput(Ops.Error, id, None) private def connectionAck(payload: Option[ResponseValue]): GraphQLWSOutput = GraphQLWSOutput(Ops.ConnectionAck, None, payload) + private def generateId(id: Option[String]): ZIO[Any, Nothing, Option[String]] = + id match { + case Some(_) => ZIO.succeed(id) + case None => Random.nextUUID.map(uuid => Some(uuid.toString)) + } + private def ping(keepAlive: Option[Duration]): UStream[Either[Nothing, GraphQLWSOutput]] = keepAlive match { case None => ZStream.empty