Skip to content

Commit

Permalink
Add support for SSE-based subscriptions for caliban-quick (#2141)
Browse files Browse the repository at this point in the history
* Add support for SSE-based subscriptions in the quick adapter

* Cleanup `AcceptsGqlEncodings`

* One more cleanup

* Always send a 200 code for SSE (Tapir)

* Unify SSE response handling
  • Loading branch information
kyri-petrou authored Feb 29, 2024
1 parent e8a8cc3 commit 3d1ad9a
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 70 deletions.
29 changes: 17 additions & 12 deletions adapters/quick/src/main/scala/caliban/QuickRequestHandler.scala
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
package caliban

import caliban.Configurator.ExecutionConfiguration
import caliban.HttpUtils.DeferMultipart
import caliban.ResponseValue.StreamValue
import caliban.HttpUtils.{ DeferMultipart, ServerSentEvents }
import caliban.ResponseValue.{ ObjectValue, StreamValue }
import caliban.Value.NullValue
import caliban.interop.jsoniter.ValueJsoniter
import caliban.uploads.{ FileMeta, GraphQLUploadRequest, Uploads }
import caliban.wrappers.Caching
Expand All @@ -12,7 +13,7 @@ import zio._
import zio.http.Header.ContentType
import zio.http._
import zio.stacktracer.TracingImplicits.disableAutoTrace
import zio.stream.ZStream
import zio.stream.{ UStream, ZStream }

import java.nio.charset.StandardCharsets.UTF_8
import scala.util.control.NonFatal
Expand Down Expand Up @@ -148,13 +149,7 @@ final private class QuickRequestHandler[-R](interpreter: GraphQLInterpreter[R, A
cacheDirective.fold(headers)(headers.addHeader(Header.CacheControl.name, _))

private def transformResponse(httpReq: Request, resp: GraphQLResponse[Any])(implicit trace: Trace): Response = {

val acceptsGqlJson: Boolean =
httpReq.headers.get(Header.Accept.name).exists { h =>
// Better performance than having to parse the Accept header
h.length >= 33 && h.toLowerCase.contains("application/graphql-response+json")
}

val accepts = new HttpUtils.AcceptsGqlEncodings(httpReq.headers.get(Header.Accept.name))
val cacheDirective = HttpUtils.computeCacheDirective(resp.extensions)

resp match {
Expand All @@ -164,7 +159,9 @@ final private class QuickRequestHandler[-R](interpreter: GraphQLInterpreter[R, A
headers = responseHeaders(ContentTypeMultipart, None),
body = Body.fromStream(encodeMultipartMixedResponse(resp, stream))
)
case resp if acceptsGqlJson =>
case resp if accepts.serverSentEvents =>
Response.fromServerSentEvents(encodeTextEventStream(resp))
case resp if accepts.graphQLJson =>
Response(
status = resp.errors.collectFirst { case _: CalibanError.ParsingError | _: CalibanError.ValidationError =>
Status.BadRequest
Expand Down Expand Up @@ -205,11 +202,17 @@ final private class QuickRequestHandler[-R](interpreter: GraphQLInterpreter[R, A
.mapConcatChunk(Chunk.fromArray)
}

private def encodeTextEventStream(resp: GraphQLResponse[Any])(implicit trace: Trace): UStream[ServerSentEvent] =
ServerSentEvents.transformResponse(
resp,
v => ServerSentEvent(writeToString(v), Some("next")),
CompleteSse
)

private def isFtv1Request(req: Request) =
req.headers
.get(GraphQLRequest.`apollo-federation-include-trace`)
.exists(_.equalsIgnoreCase(GraphQLRequest.ftv1))

}

object QuickRequestHandler {
Expand All @@ -225,6 +228,8 @@ object QuickRequestHandler {
private val ContentTypeMultipart =
Headers(Header.ContentType(MediaType.multipart.mixed.copy(parameters = DeferMultipart.DeferHeaderParams)).untyped)

private val CompleteSse = ServerSentEvent("", Some("complete"))

private val BodyDecodeErrorResponse =
badRequest("Failed to decode json body")

Expand Down
3 changes: 1 addition & 2 deletions adapters/quick/src/test/scala/caliban/QuickAdapterSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,7 @@ object QuickAdapterSpec extends ZIOSpecDefault {
"QuickAdapterSpec",
uri"http://localhost:8090/api/graphql",
wsUri = None,
uploadUri = Some(uri"http://localhost:8090/upload/graphql"),
sseSupport = None
uploadUri = Some(uri"http://localhost:8090/upload/graphql")
)
suite.provideShared(
apiLayer,
Expand Down
44 changes: 40 additions & 4 deletions core/src/main/scala/caliban/HttpUtils.scala
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
package caliban

import caliban.ResponseValue.{ ObjectValue, StreamValue }
import caliban.Value.NullValue
import caliban.wrappers.Caching
import zio.stream.{ ZChannel, ZPipeline }
import zio.{ Cause, Chunk }
import zio.stream.{ UStream, ZChannel, ZPipeline, ZStream }
import zio.{ Cause, Chunk, Trace }

private[caliban] object HttpUtils {

Expand Down Expand Up @@ -38,10 +40,44 @@ private[caliban] object HttpUtils {
}
}

object ServerSentEvents {

def transformResponse[Sse](
resp: GraphQLResponse[Any],
toSse: ResponseValue => Sse,
done: Sse
)(implicit trace: Trace): UStream[Sse] =
(resp.data match {
case ObjectValue((fieldName, StreamValue(stream)) :: Nil) =>
// Report errors in an initial event sent immediately
val init =
if (resp.errors.isEmpty) ZStream.empty else ZStream.succeed(GraphQLResponse(NullValue, resp.errors))
init ++ stream.either.map {
case Right(r) => GraphQLResponse(ObjectValue(List(fieldName -> r)), Nil)
case Left(err) => GraphQLResponse(ObjectValue(List(fieldName -> NullValue)), List(err))
}
case _ => ZStream.succeed(resp)
}).map(v => toSse(v.toResponseValue)) ++ ZStream.succeed(done)
}

def computeCacheDirective(extensions: Option[ResponseValue.ObjectValue]): Option[String] =
extensions
.flatMap(_.fields.collectFirst { case (Caching.DirectiveName, ResponseValue.ObjectValue(fields)) =>
fields.collectFirst { case ("httpHeader", Value.StringValue(cacheHeader)) => cacheHeader }
})
.flatten
}.flatten)

final class AcceptsGqlEncodings(header0: Option[String]) {
private val isEmpty = header0.isEmpty
private val length = if (isEmpty) 0 else header0.get.length
private lazy val header = if (isEmpty) "" else header0.get.toLowerCase

/**
* NOTE: From 1st January 2025 this should be changed to `true` as the default
*
* @see [[https://graphql.github.io/graphql-over-http/draft/#sec-Legacy-watershed]]
*/
def graphQLJson: Boolean = length >= 33 && header.contains("application/graphql-response+json")

def serverSentEvents: Boolean = length >= 17 && header.contains("text/event-stream")
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -106,24 +106,7 @@ object TapirAdapter {
streamConstructor: StreamConstructor[BS],
responseCodec: JsonCodec[ResponseValue]
): (MediaType, StatusCode, Option[String], CalibanBody[BS]) = {

/**
* NOTE: From 1st January 2025 this logic should be changed to use `application/graphql-response+json` as the
* default content-type when the client does not specify an accept header.
*
* @see [[https://graphql.github.io/graphql-over-http/draft/#sec-Legacy-watershed]]
*/
def accepts = request.acceptsContentTypes
.fold(
_ => None,
_.find {
case ContentTypeRange("application", "graphql-response+json", _, _) => true
case ContentTypeRange("text", "event-stream", _, _) => true
case _ => false
}
)
.map(ct => MediaType(ct.mainType, ct.subType))
.getOrElse(MediaType.ApplicationJson)
val accepts = new HttpUtils.AcceptsGqlEncodings(request.header(HeaderNames.Accept))

response match {
case resp @ GraphQLResponse(StreamValue(stream), _, _, _) =>
Expand All @@ -133,7 +116,7 @@ object TapirAdapter {
None,
encodeMultipartMixedResponse(resp, stream)
)
case resp if accepts == GraphqlResponseJson.mediaType =>
case resp if accepts.graphQLJson =>
val code =
response.errors.collectFirst { case _: CalibanError.ParsingError | _: CalibanError.ValidationError =>
StatusCode.BadRequest
Expand All @@ -149,12 +132,10 @@ object TapirAdapter {
excludeExtensions = cacheDirective.map(_ => Set(Caching.DirectiveName))
)
)
case resp if accepts == GraphqlServerSentEvent.mediaType =>
val code = response.errors.collectFirst { case HttpRequestMethod.MutationOverGetError => StatusCode.BadRequest }
.getOrElse(StatusCode.Ok)
case resp if accepts.serverSentEvents =>
(
MediaType.TextEventStream,
code,
StatusCode.Ok,
None,
encodeTextEventStreamResponse(resp)
)
Expand Down Expand Up @@ -225,31 +206,11 @@ object TapirAdapter {
private def encodeTextEventStreamResponse[E, BS](
resp: GraphQLResponse[E]
)(implicit streamConstructor: StreamConstructor[BS], responseCodec: JsonCodec[ResponseValue]): CalibanBody[BS] = {
val response: ZStream[Any, Throwable, ServerSentEvent] = (resp.data match {
case ObjectValue(fields) =>
fields.foldLeft(ZStream.empty: ZStream[Any, Throwable, ServerSentEvent]) { case (_, v) =>
v match {
case (fieldName, StreamValue(stream)) =>
stream.map { r =>
ServerSentEvent(
Some(
responseCodec.encode(
GraphQLResponse(
ObjectValue(List(fieldName -> r)),
resp.errors
).toResponseValue
)
),
Some("next")
)
}
case _ =>
ZStream.succeed(ServerSentEvent(Some(responseCodec.encode(resp.toResponseValue)), Some("next")))
}
}
case _ =>
ZStream.succeed(ServerSentEvent(Some(responseCodec.encode(resp.toResponseValue)), Some("next")))
}) ++ ZStream.succeed(ServerSentEvent(None, Some("complete")))
val response = HttpUtils.ServerSentEvents.transformResponse(
resp,
v => ServerSentEvent(Some(responseCodec.encode(v)), Some("next")),
ServerSentEvent(None, Some("complete"))
)
Right(streamConstructor(ZioServerSentEvents.serialiseSSEToBytes(response)))
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ object TapirAdapterSpec {
httpUri: Uri,
uploadUri: Option[Uri] = None,
wsUri: Option[Uri] = None,
sseSupport: Option[Boolean] = Some(true)
sseSupport: Boolean = true
)(implicit
requestCodec: JsonCodec[GraphQLRequest],
responseCodec: JsonCodec[GraphQLResponse[CalibanError]],
Expand Down Expand Up @@ -269,8 +269,8 @@ object TapirAdapterSpec {
}
)
),
sseSupport.map(_ =>
suite("SSE")(
Some(
suite("server-sent events")(
test("TextEventStream") {
for {
res <- runSSERequest(
Expand All @@ -290,7 +290,7 @@ object TapirAdapterSpec {
)
)
} @@ TestAspect.timeout(10.seconds)
)
).when(sseSupport)
),
runUpload.map(runUpload =>
suite("uploads")(
Expand Down

0 comments on commit 3d1ad9a

Please sign in to comment.