diff --git a/interop/tapir/src/main/scala/caliban/interop/tapir/ws/Protocol.scala b/interop/tapir/src/main/scala/caliban/interop/tapir/ws/Protocol.scala index b19b6a82d7..506a242f13 100644 --- a/interop/tapir/src/main/scala/caliban/interop/tapir/ws/Protocol.scala +++ b/interop/tapir/src/main/scala/caliban/interop/tapir/ws/Protocol.scala @@ -8,8 +8,7 @@ import caliban.interop.tapir.TapirAdapter.CalibanPipe import caliban.interop.tapir.WebSocketHooks import zio.stm.TMap import zio.stream.{ UStream, ZStream } -import zio.{ Duration, Promise, Queue, Ref, Schedule, UIO, URIO, ZIO } -import java.util.UUID +import zio.{ Duration, Promise, Queue, Random, Ref, Schedule, UIO, URIO, ZIO } sealed trait Protocol { def name: String @@ -50,20 +49,15 @@ object Protocol { override def complete(id: String): GraphQLWSOutput = GraphQLWSOutput(Ops.Complete, Some(id), None) - override def error[E](id: Option[String], e: E): GraphQLWSOutput = { - val outputId = id match { - case None => Some(UUID.randomUUID().toString) - case Some(_) => id - } + override def error[E](id: Option[String], e: E): GraphQLWSOutput = GraphQLWSOutput( Ops.Error, - outputId, + id, Some(ResponseValue.ListValue(List(e match { case e: CalibanError => e.toResponseValue case e => StringValue(e.toString) }))) ) - } } override def make[R, E]( @@ -92,7 +86,7 @@ object Protocol { afterInit .catchAllCause(cause => ZIO.foreachDiscard(cause.failureOption)(e => - output.offer(Right(handler.error(id, e))) + generateId(id).map(uuid => output.offer(Right(handler.error(uuid, e)))) ) *> output.offer(Left(GraphQLWSClose(4401, "Unauthorized"))) ) .fork @@ -101,14 +95,18 @@ object Protocol { before *> response *> ka *> after case GraphQLWSInput(Ops.Pong, id, payload) => ZIO.whenCase(webSocketHooks.onPong -> payload) { case (Some(onPong), Some(payload)) => - onPong(payload).catchAll(e => output.offer(Right(handler.error(id, e)))) + onPong(payload).catchAll(e => + generateId(id).map(uuid => output.offer(Right(handler.error(uuid, e)))) + ) } case GraphQLWSInput(Ops.Ping, id, payload) => def sendPong(p: Option[ResponseValue]) = output.offer(Right(GraphQLWSOutput(Ops.Pong, id, p))) webSocketHooks.onPing match { case Some(onPing) => - onPing(payload).flatMap(sendPong).catchAll(e => output.offer(Right(handler.error(id, e)))) + onPing(payload) + .flatMap(sendPong) + .catchAll(e => generateId(id).map(uuid => output.offer(Right(handler.error(uuid, e))))) case _ => sendPong(None) } case GraphQLWSInput(Ops.Subscribe, Some(id), payload) => @@ -137,7 +135,8 @@ object Protocol { .unit ) - case None => output.offer(Right(connectionError)) + case None => + generateId(None).map(uuid => output.offer(Right(connectionError(uuid)))) } ZIO.ifZIO(ack.get)(continue, output.offer(Left(GraphQLWSClose(4401, "Unauthorized")))) @@ -146,7 +145,9 @@ object Protocol { case GraphQLWSInput(unsupported, _, _) => output.offer(Left(GraphQLWSClose(4400, s"Unsupported operation: $unsupported"))) }.runDrain.interruptible - .catchAll(_ => output.offer(Right(connectionError))) + .catchAll(_ => + generateId(None).map(uuid => output.offer(Right(connectionError(Some(uuid.toString()))))) + ) .ensuring(subscriptions.untrackAll) .provideEnvironment(env) .forkScoped @@ -154,10 +155,16 @@ object Protocol { } } yield pipe - private val connectionError: GraphQLWSOutput = GraphQLWSOutput(Ops.Error, None, None) + private def connectionError(id: Option[String]): GraphQLWSOutput = GraphQLWSOutput(Ops.Error, id, None) private def connectionAck(payload: Option[ResponseValue]): GraphQLWSOutput = GraphQLWSOutput(Ops.ConnectionAck, None, payload) + private def generateId(id: Option[String]): ZIO[Any, Nothing, Option[String]] = + id match { + case Some(_) => ZIO.succeed(id) + case None => Random.nextUUID.map(uuid => Some(uuid.toString)) + } + private def ping(keepAlive: Option[Duration]): UStream[Either[Nothing, GraphQLWSOutput]] = keepAlive match { case None => ZStream.empty