From 712f74cabd4c6342d3eb4bc6838560d4a97ff007 Mon Sep 17 00:00:00 2001 From: Pierre Ricadat Date: Mon, 17 Feb 2020 23:06:06 +0900 Subject: [PATCH] Add support for WebSockets in Akka HTTP Adapter (#219) * Add support for WebSockets in Akka Adapter * Clean up map on stop --- .../main/scala/caliban/AkkaHttpAdapter.scala | 113 +++++++++++++++++- .../scala/caliban/akkahttp/ExampleApp.scala | 23 +++- .../main/scala/caliban/Http4sAdapter.scala | 4 +- 3 files changed, 128 insertions(+), 12 deletions(-) diff --git a/akka-http/src/main/scala/caliban/AkkaHttpAdapter.scala b/akka-http/src/main/scala/caliban/AkkaHttpAdapter.scala index 2b2ae8851..0438d0815 100644 --- a/akka-http/src/main/scala/caliban/AkkaHttpAdapter.scala +++ b/akka-http/src/main/scala/caliban/AkkaHttpAdapter.scala @@ -1,31 +1,36 @@ package caliban +import scala.concurrent.ExecutionContext import akka.http.scaladsl.model.MediaTypes.`application/json` +import akka.http.scaladsl.model.ws.{ Message, TextMessage } import akka.http.scaladsl.model.{ HttpEntity, HttpResponse, StatusCodes } import akka.http.scaladsl.server.Directives.complete import akka.http.scaladsl.server.{ Route, StandardRoute } +import akka.stream.scaladsl.{ Flow, Sink, Source, SourceQueueWithComplete } +import akka.stream.{ Materializer, OverflowStrategy, QueueOfferResult } +import caliban.ResponseValue.{ ObjectValue, StreamValue } import caliban.Value.NullValue import de.heikoseeberger.akkahttpcirce.FailFastCirceSupport import io.circe.Decoder.Result import io.circe.Json +import io.circe.parser._ import io.circe.syntax._ -import zio.{ Runtime, URIO } - -import scala.concurrent.ExecutionContext +import zio.{ Fiber, IO, Ref, Runtime, Task, URIO } object AkkaHttpAdapter extends FailFastCirceSupport { + private def execute[R, E](interpreter: GraphQLInterpreter[R, E], query: GraphQLRequest): URIO[R, GraphQLResponse[E]] = + interpreter.execute(query.query, query.operationName, query.variables.getOrElse(Map())) + private def executeHttpResponse[R, E]( interpreter: GraphQLInterpreter[R, E], request: GraphQLRequest ): URIO[R, HttpResponse] = - interpreter - .execute(request.query, request.operationName, request.variables.getOrElse(Map())) + execute(interpreter, request) .foldCause(cause => GraphQLResponse(NullValue, cause.defects).asJson, _.asJson) .map(gqlResult => HttpResponse(StatusCodes.OK, entity = HttpEntity(`application/json`, gqlResult.toString()))) def getGraphQLRequest(query: String, op: Option[String], vars: Option[String]): Result[GraphQLRequest] = { - import io.circe.parser._ val variablesJs = vars.flatMap(parse(_).toOption) val fields = List("query" -> Json.fromString(query)) ++ op.map(o => "operationName" -> Json.fromString(o)) ++ @@ -60,4 +65,100 @@ object AkkaHttpAdapter extends FailFastCirceSupport { entity(as[GraphQLRequest]) { completeRequest(interpreter) } } } + + def makeWebSocketService[R, E]( + interpreter: GraphQLInterpreter[R, E] + )(implicit ec: ExecutionContext, runtime: Runtime[R], materializer: Materializer): Route = { + def sendMessage( + sendQueue: SourceQueueWithComplete[Message], + id: String, + data: ResponseValue, + errors: List[E] + ): Task[QueueOfferResult] = + IO.fromFuture( + _ => + sendQueue.offer( + TextMessage( + Json + .obj( + "id" -> Json.fromString(id), + "type" -> Json.fromString("data"), + "payload" -> GraphQLResponse(data, errors).asJson + ) + .noSpaces + ) + ) + ) + + import akka.http.scaladsl.server.Directives._ + + get { + extractUpgradeToWebSocket { upgrade => + val (queue, source) = Source.queue[Message](0, OverflowStrategy.fail).preMaterialize() + val subscriptions = runtime.unsafeRun(Ref.make(Map.empty[String, Fiber[Throwable, Unit]])) + val sink = Sink.foreach[Message] { + case TextMessage.Strict(text) => + val io = for { + msg <- Task.fromEither(decode[Json](text)) + msgType = msg.hcursor.downField("type").success.flatMap(_.value.asString).getOrElse("") + _ <- IO.whenCase(msgType) { + case "connection_init" => + IO.fromFuture(_ => queue.offer(TextMessage("""{"type":"connection_ack"}"""))) + case "connection_terminate" => + IO.effect(queue.complete()) + case "start" => + val payload = msg.hcursor.downField("payload") + val id = msg.hcursor.downField("id").success.flatMap(_.value.asString).getOrElse("") + Task.whenCase(payload.downField("query").success.flatMap(_.value.asString)) { + case Some(query) => + val operationName = payload.downField("operationName").success.flatMap(_.value.asString) + (for { + result <- execute(interpreter, GraphQLRequest(query, operationName, None)) + _ <- result.data match { + case ObjectValue((fieldName, StreamValue(stream)) :: Nil) => + stream.foreach { item => + sendMessage(queue, id, ObjectValue(List(fieldName -> item)), result.errors) + }.fork.flatMap(fiber => subscriptions.update(_.updated(id, fiber))) + case other => + sendMessage(queue, id, other, result.errors) *> IO.fromFuture( + _ => queue.offer(TextMessage(s"""{"type":"complete","id":"$id"}""")) + ) + } + } yield ()).catchAll( + error => + IO.fromFuture( + _ => + queue.offer( + TextMessage( + Json + .obj( + "id" -> Json.fromString(id), + "type" -> Json.fromString("complete"), + "payload" -> Json.fromString(error.toString) + ) + .noSpaces + ) + ) + ) + ) + } + case "stop" => + val id = msg.hcursor.downField("id").success.flatMap(_.value.asString).getOrElse("") + subscriptions + .modify(map => (map.get(id), map - id)) + .flatMap(fiber => IO.whenCase(fiber) { case Some(fiber) => fiber.interrupt }) + } + } yield () + runtime.unsafeRun(io) + case _ => () + } + + val flow = Flow.fromSinkAndSource(sink, source).watchTermination() { (_, f) => + f.onComplete(_ => runtime.unsafeRun(subscriptions.get.flatMap(m => IO.foreach(m.values)(_.interrupt).unit))) + } + + complete(upgrade.handleMessages(flow, subprotocol = Some("graphql-ws"))) + } + } + } } diff --git a/examples/src/main/scala/caliban/akkahttp/ExampleApp.scala b/examples/src/main/scala/caliban/akkahttp/ExampleApp.scala index 646c377c7..5855dcd62 100644 --- a/examples/src/main/scala/caliban/akkahttp/ExampleApp.scala +++ b/examples/src/main/scala/caliban/akkahttp/ExampleApp.scala @@ -1,5 +1,6 @@ package caliban.akkahttp +import scala.language.postfixOps import scala.io.StdIn import akka.actor.ActorSystem import akka.http.scaladsl.Http @@ -8,9 +9,12 @@ import caliban.ExampleData.{ sampleCharacters, Character, CharacterArgs, Charact import caliban.GraphQL.graphQL import caliban.schema.Annotations.{ GQLDeprecated, GQLDescription } import caliban.schema.GenericSchema -import caliban.{ AkkaHttpAdapter, ExampleService, RootResolver } +import caliban.wrappers.ApolloTracing.apolloTracing +import caliban.wrappers.Wrappers._ +import caliban.{ AkkaHttpAdapter, ExampleService, GraphQL, RootResolver } import zio.clock.Clock import zio.console.Console +import zio.duration._ import zio.stream.ZStream import zio.{ DefaultRuntime, URIO } @@ -34,9 +38,7 @@ object ExampleApp extends App with GenericSchema[Console with Clock] { case class Mutations(deleteCharacter: CharacterArgs => URIO[Console, Boolean]) case class Subscriptions(characterDeleted: ZStream[Console, Nothing, String]) - val service = defaultRuntime.unsafeRun(ExampleService.make(sampleCharacters)) - - val interpreter = + def makeApi(service: ExampleService): GraphQL[Console with Clock] = graphQL( RootResolver( Queries( @@ -46,7 +48,16 @@ object ExampleApp extends App with GenericSchema[Console with Clock] { Mutations(args => service.deleteCharacter(args.name)), Subscriptions(service.deletedEvents) ) - ).interpreter + ) @@ + maxFields(200) @@ // query analyzer that limit query fields + maxDepth(30) @@ // query analyzer that limit query depth + timeout(3 seconds) @@ // wrapper that fails slow queries + printSlowQueries(500 millis) @@ // wrapper that logs slow queries + apolloTracing // wrapper for https://github.com/apollographql/apollo-tracing + + val service = defaultRuntime.unsafeRun(ExampleService.make(sampleCharacters)) + + val interpreter = makeApi(service).interpreter /** * curl -X POST \ @@ -60,6 +71,8 @@ object ExampleApp extends App with GenericSchema[Console with Clock] { val route = path("api" / "graphql") { AkkaHttpAdapter.makeHttpService(interpreter) + } ~ path("ws" / "graphql") { + AkkaHttpAdapter.makeWebSocketService(interpreter) } ~ path("graphiql") { getFromResource("graphiql.html") } diff --git a/http4s/src/main/scala/caliban/Http4sAdapter.scala b/http4s/src/main/scala/caliban/Http4sAdapter.scala index 492b48373..b3a618be9 100644 --- a/http4s/src/main/scala/caliban/Http4sAdapter.scala +++ b/http4s/src/main/scala/caliban/Http4sAdapter.scala @@ -152,7 +152,9 @@ object Http4sAdapter { } case "stop" => val id = msg.hcursor.downField("id").success.flatMap(_.value.asString).getOrElse("") - subscriptions.get.flatMap(map => IO.whenCase(map.get(id)) { case Some(fiber) => fiber.interrupt }) + subscriptions + .modify(map => (map.get(id), map - id)) + .flatMap(fiber => IO.whenCase(fiber) { case Some(fiber) => fiber.interrupt }) } } yield () }