From 3d1ad9abc8c94a68c17984f6c33d81f14c52f7ca Mon Sep 17 00:00:00 2001 From: kyri-petrou <67301607+kyri-petrou@users.noreply.github.com> Date: Thu, 29 Feb 2024 08:03:54 -0400 Subject: [PATCH] Add support for SSE-based subscriptions for `caliban-quick` (#2141) * Add support for SSE-based subscriptions in the quick adapter * Cleanup `AcceptsGqlEncodings` * One more cleanup * Always send a 200 code for SSE (Tapir) * Unify SSE response handling --- .../scala/caliban/QuickRequestHandler.scala | 29 ++++++---- .../test/scala/caliban/QuickAdapterSpec.scala | 3 +- core/src/main/scala/caliban/HttpUtils.scala | 44 ++++++++++++-- .../caliban/interop/tapir/TapirAdapter.scala | 57 +++---------------- .../interop/tapir/TapirAdapterSpec.scala | 8 +-- 5 files changed, 71 insertions(+), 70 deletions(-) diff --git a/adapters/quick/src/main/scala/caliban/QuickRequestHandler.scala b/adapters/quick/src/main/scala/caliban/QuickRequestHandler.scala index 12da1a14d..b4442088d 100644 --- a/adapters/quick/src/main/scala/caliban/QuickRequestHandler.scala +++ b/adapters/quick/src/main/scala/caliban/QuickRequestHandler.scala @@ -1,8 +1,9 @@ package caliban import caliban.Configurator.ExecutionConfiguration -import caliban.HttpUtils.DeferMultipart -import caliban.ResponseValue.StreamValue +import caliban.HttpUtils.{ DeferMultipart, ServerSentEvents } +import caliban.ResponseValue.{ ObjectValue, StreamValue } +import caliban.Value.NullValue import caliban.interop.jsoniter.ValueJsoniter import caliban.uploads.{ FileMeta, GraphQLUploadRequest, Uploads } import caliban.wrappers.Caching @@ -12,7 +13,7 @@ import zio._ import zio.http.Header.ContentType import zio.http._ import zio.stacktracer.TracingImplicits.disableAutoTrace -import zio.stream.ZStream +import zio.stream.{ UStream, ZStream } import java.nio.charset.StandardCharsets.UTF_8 import scala.util.control.NonFatal @@ -148,13 +149,7 @@ final private class QuickRequestHandler[-R](interpreter: GraphQLInterpreter[R, A cacheDirective.fold(headers)(headers.addHeader(Header.CacheControl.name, _)) private def transformResponse(httpReq: Request, resp: GraphQLResponse[Any])(implicit trace: Trace): Response = { - - val acceptsGqlJson: Boolean = - httpReq.headers.get(Header.Accept.name).exists { h => - // Better performance than having to parse the Accept header - h.length >= 33 && h.toLowerCase.contains("application/graphql-response+json") - } - + val accepts = new HttpUtils.AcceptsGqlEncodings(httpReq.headers.get(Header.Accept.name)) val cacheDirective = HttpUtils.computeCacheDirective(resp.extensions) resp match { @@ -164,7 +159,9 @@ final private class QuickRequestHandler[-R](interpreter: GraphQLInterpreter[R, A headers = responseHeaders(ContentTypeMultipart, None), body = Body.fromStream(encodeMultipartMixedResponse(resp, stream)) ) - case resp if acceptsGqlJson => + case resp if accepts.serverSentEvents => + Response.fromServerSentEvents(encodeTextEventStream(resp)) + case resp if accepts.graphQLJson => Response( status = resp.errors.collectFirst { case _: CalibanError.ParsingError | _: CalibanError.ValidationError => Status.BadRequest @@ -205,11 +202,17 @@ final private class QuickRequestHandler[-R](interpreter: GraphQLInterpreter[R, A .mapConcatChunk(Chunk.fromArray) } + private def encodeTextEventStream(resp: GraphQLResponse[Any])(implicit trace: Trace): UStream[ServerSentEvent] = + ServerSentEvents.transformResponse( + resp, + v => ServerSentEvent(writeToString(v), Some("next")), + CompleteSse + ) + private def isFtv1Request(req: Request) = req.headers .get(GraphQLRequest.`apollo-federation-include-trace`) .exists(_.equalsIgnoreCase(GraphQLRequest.ftv1)) - } object QuickRequestHandler { @@ -225,6 +228,8 @@ object QuickRequestHandler { private val ContentTypeMultipart = Headers(Header.ContentType(MediaType.multipart.mixed.copy(parameters = DeferMultipart.DeferHeaderParams)).untyped) + private val CompleteSse = ServerSentEvent("", Some("complete")) + private val BodyDecodeErrorResponse = badRequest("Failed to decode json body") diff --git a/adapters/quick/src/test/scala/caliban/QuickAdapterSpec.scala b/adapters/quick/src/test/scala/caliban/QuickAdapterSpec.scala index e0698fb6d..65c5413bf 100644 --- a/adapters/quick/src/test/scala/caliban/QuickAdapterSpec.scala +++ b/adapters/quick/src/test/scala/caliban/QuickAdapterSpec.scala @@ -36,8 +36,7 @@ object QuickAdapterSpec extends ZIOSpecDefault { "QuickAdapterSpec", uri"http://localhost:8090/api/graphql", wsUri = None, - uploadUri = Some(uri"http://localhost:8090/upload/graphql"), - sseSupport = None + uploadUri = Some(uri"http://localhost:8090/upload/graphql") ) suite.provideShared( apiLayer, diff --git a/core/src/main/scala/caliban/HttpUtils.scala b/core/src/main/scala/caliban/HttpUtils.scala index 455e101ea..5fead8d84 100644 --- a/core/src/main/scala/caliban/HttpUtils.scala +++ b/core/src/main/scala/caliban/HttpUtils.scala @@ -1,8 +1,10 @@ package caliban +import caliban.ResponseValue.{ ObjectValue, StreamValue } +import caliban.Value.NullValue import caliban.wrappers.Caching -import zio.stream.{ ZChannel, ZPipeline } -import zio.{ Cause, Chunk } +import zio.stream.{ UStream, ZChannel, ZPipeline, ZStream } +import zio.{ Cause, Chunk, Trace } private[caliban] object HttpUtils { @@ -38,10 +40,44 @@ private[caliban] object HttpUtils { } } + object ServerSentEvents { + + def transformResponse[Sse]( + resp: GraphQLResponse[Any], + toSse: ResponseValue => Sse, + done: Sse + )(implicit trace: Trace): UStream[Sse] = + (resp.data match { + case ObjectValue((fieldName, StreamValue(stream)) :: Nil) => + // Report errors in an initial event sent immediately + val init = + if (resp.errors.isEmpty) ZStream.empty else ZStream.succeed(GraphQLResponse(NullValue, resp.errors)) + init ++ stream.either.map { + case Right(r) => GraphQLResponse(ObjectValue(List(fieldName -> r)), Nil) + case Left(err) => GraphQLResponse(ObjectValue(List(fieldName -> NullValue)), List(err)) + } + case _ => ZStream.succeed(resp) + }).map(v => toSse(v.toResponseValue)) ++ ZStream.succeed(done) + } + def computeCacheDirective(extensions: Option[ResponseValue.ObjectValue]): Option[String] = extensions .flatMap(_.fields.collectFirst { case (Caching.DirectiveName, ResponseValue.ObjectValue(fields)) => fields.collectFirst { case ("httpHeader", Value.StringValue(cacheHeader)) => cacheHeader } - }) - .flatten + }.flatten) + + final class AcceptsGqlEncodings(header0: Option[String]) { + private val isEmpty = header0.isEmpty + private val length = if (isEmpty) 0 else header0.get.length + private lazy val header = if (isEmpty) "" else header0.get.toLowerCase + + /** + * NOTE: From 1st January 2025 this should be changed to `true` as the default + * + * @see [[https://graphql.github.io/graphql-over-http/draft/#sec-Legacy-watershed]] + */ + def graphQLJson: Boolean = length >= 33 && header.contains("application/graphql-response+json") + + def serverSentEvents: Boolean = length >= 17 && header.contains("text/event-stream") + } } diff --git a/interop/tapir/src/main/scala/caliban/interop/tapir/TapirAdapter.scala b/interop/tapir/src/main/scala/caliban/interop/tapir/TapirAdapter.scala index b46273eb0..a822e8c19 100644 --- a/interop/tapir/src/main/scala/caliban/interop/tapir/TapirAdapter.scala +++ b/interop/tapir/src/main/scala/caliban/interop/tapir/TapirAdapter.scala @@ -106,24 +106,7 @@ object TapirAdapter { streamConstructor: StreamConstructor[BS], responseCodec: JsonCodec[ResponseValue] ): (MediaType, StatusCode, Option[String], CalibanBody[BS]) = { - - /** - * NOTE: From 1st January 2025 this logic should be changed to use `application/graphql-response+json` as the - * default content-type when the client does not specify an accept header. - * - * @see [[https://graphql.github.io/graphql-over-http/draft/#sec-Legacy-watershed]] - */ - def accepts = request.acceptsContentTypes - .fold( - _ => None, - _.find { - case ContentTypeRange("application", "graphql-response+json", _, _) => true - case ContentTypeRange("text", "event-stream", _, _) => true - case _ => false - } - ) - .map(ct => MediaType(ct.mainType, ct.subType)) - .getOrElse(MediaType.ApplicationJson) + val accepts = new HttpUtils.AcceptsGqlEncodings(request.header(HeaderNames.Accept)) response match { case resp @ GraphQLResponse(StreamValue(stream), _, _, _) => @@ -133,7 +116,7 @@ object TapirAdapter { None, encodeMultipartMixedResponse(resp, stream) ) - case resp if accepts == GraphqlResponseJson.mediaType => + case resp if accepts.graphQLJson => val code = response.errors.collectFirst { case _: CalibanError.ParsingError | _: CalibanError.ValidationError => StatusCode.BadRequest @@ -149,12 +132,10 @@ object TapirAdapter { excludeExtensions = cacheDirective.map(_ => Set(Caching.DirectiveName)) ) ) - case resp if accepts == GraphqlServerSentEvent.mediaType => - val code = response.errors.collectFirst { case HttpRequestMethod.MutationOverGetError => StatusCode.BadRequest } - .getOrElse(StatusCode.Ok) + case resp if accepts.serverSentEvents => ( MediaType.TextEventStream, - code, + StatusCode.Ok, None, encodeTextEventStreamResponse(resp) ) @@ -225,31 +206,11 @@ object TapirAdapter { private def encodeTextEventStreamResponse[E, BS]( resp: GraphQLResponse[E] )(implicit streamConstructor: StreamConstructor[BS], responseCodec: JsonCodec[ResponseValue]): CalibanBody[BS] = { - val response: ZStream[Any, Throwable, ServerSentEvent] = (resp.data match { - case ObjectValue(fields) => - fields.foldLeft(ZStream.empty: ZStream[Any, Throwable, ServerSentEvent]) { case (_, v) => - v match { - case (fieldName, StreamValue(stream)) => - stream.map { r => - ServerSentEvent( - Some( - responseCodec.encode( - GraphQLResponse( - ObjectValue(List(fieldName -> r)), - resp.errors - ).toResponseValue - ) - ), - Some("next") - ) - } - case _ => - ZStream.succeed(ServerSentEvent(Some(responseCodec.encode(resp.toResponseValue)), Some("next"))) - } - } - case _ => - ZStream.succeed(ServerSentEvent(Some(responseCodec.encode(resp.toResponseValue)), Some("next"))) - }) ++ ZStream.succeed(ServerSentEvent(None, Some("complete"))) + val response = HttpUtils.ServerSentEvents.transformResponse( + resp, + v => ServerSentEvent(Some(responseCodec.encode(v)), Some("next")), + ServerSentEvent(None, Some("complete")) + ) Right(streamConstructor(ZioServerSentEvents.serialiseSSEToBytes(response))) } diff --git a/interop/tapir/src/test/scala/caliban/interop/tapir/TapirAdapterSpec.scala b/interop/tapir/src/test/scala/caliban/interop/tapir/TapirAdapterSpec.scala index 7013d5347..d78f28007 100644 --- a/interop/tapir/src/test/scala/caliban/interop/tapir/TapirAdapterSpec.scala +++ b/interop/tapir/src/test/scala/caliban/interop/tapir/TapirAdapterSpec.scala @@ -53,7 +53,7 @@ object TapirAdapterSpec { httpUri: Uri, uploadUri: Option[Uri] = None, wsUri: Option[Uri] = None, - sseSupport: Option[Boolean] = Some(true) + sseSupport: Boolean = true )(implicit requestCodec: JsonCodec[GraphQLRequest], responseCodec: JsonCodec[GraphQLResponse[CalibanError]], @@ -269,8 +269,8 @@ object TapirAdapterSpec { } ) ), - sseSupport.map(_ => - suite("SSE")( + Some( + suite("server-sent events")( test("TextEventStream") { for { res <- runSSERequest( @@ -290,7 +290,7 @@ object TapirAdapterSpec { ) ) } @@ TestAspect.timeout(10.seconds) - ) + ).when(sseSupport) ), runUpload.map(runUpload => suite("uploads")(