Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for WebSockets in Akka HTTP Adapter #219

Merged
merged 2 commits into from
Feb 17, 2020
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 107 additions & 6 deletions akka-http/src/main/scala/caliban/AkkaHttpAdapter.scala
Original file line number Diff line number Diff line change
@@ -1,31 +1,36 @@
package caliban

import scala.concurrent.ExecutionContext
import akka.http.scaladsl.model.MediaTypes.`application/json`
import akka.http.scaladsl.model.ws.{ Message, TextMessage }
import akka.http.scaladsl.model.{ HttpEntity, HttpResponse, StatusCodes }
import akka.http.scaladsl.server.Directives.complete
import akka.http.scaladsl.server.{ Route, StandardRoute }
import akka.stream.scaladsl.{ Flow, Sink, Source, SourceQueueWithComplete }
import akka.stream.{ Materializer, OverflowStrategy, QueueOfferResult }
import caliban.ResponseValue.{ ObjectValue, StreamValue }
import caliban.Value.NullValue
import de.heikoseeberger.akkahttpcirce.FailFastCirceSupport
import io.circe.Decoder.Result
import io.circe.Json
import io.circe.parser._
import io.circe.syntax._
import zio.{ Runtime, URIO }

import scala.concurrent.ExecutionContext
import zio.{ Fiber, IO, Ref, Runtime, Task, URIO }

object AkkaHttpAdapter extends FailFastCirceSupport {

private def execute[R, E](interpreter: GraphQLInterpreter[R, E], query: GraphQLRequest): URIO[R, GraphQLResponse[E]] =
interpreter.execute(query.query, query.operationName, query.variables.getOrElse(Map()))

private def executeHttpResponse[R, E](
interpreter: GraphQLInterpreter[R, E],
request: GraphQLRequest
): URIO[R, HttpResponse] =
interpreter
.execute(request.query, request.operationName, request.variables.getOrElse(Map()))
execute(interpreter, request)
.foldCause(cause => GraphQLResponse(NullValue, cause.defects).asJson, _.asJson)
.map(gqlResult => HttpResponse(StatusCodes.OK, entity = HttpEntity(`application/json`, gqlResult.toString())))

def getGraphQLRequest(query: String, op: Option[String], vars: Option[String]): Result[GraphQLRequest] = {
import io.circe.parser._
val variablesJs = vars.flatMap(parse(_).toOption)
val fields = List("query" -> Json.fromString(query)) ++
op.map(o => "operationName" -> Json.fromString(o)) ++
@@ -60,4 +65,100 @@ object AkkaHttpAdapter extends FailFastCirceSupport {
entity(as[GraphQLRequest]) { completeRequest(interpreter) }
}
}

def makeWebSocketService[R, E](
interpreter: GraphQLInterpreter[R, E]
)(implicit ec: ExecutionContext, runtime: Runtime[R], materializer: Materializer): Route = {
def sendMessage(
sendQueue: SourceQueueWithComplete[Message],
id: String,
data: ResponseValue,
errors: List[E]
): Task[QueueOfferResult] =
IO.fromFuture(
_ =>
sendQueue.offer(
TextMessage(
Json
.obj(
"id" -> Json.fromString(id),
"type" -> Json.fromString("data"),
"payload" -> GraphQLResponse(data, errors).asJson
)
.noSpaces
)
)
)

import akka.http.scaladsl.server.Directives._

get {
extractUpgradeToWebSocket { upgrade =>
val (queue, source) = Source.queue[Message](0, OverflowStrategy.fail).preMaterialize()
val subscriptions = runtime.unsafeRun(Ref.make(Map.empty[String, Fiber[Throwable, Unit]]))
val sink = Sink.foreach[Message] {
case TextMessage.Strict(text) =>
val io = for {
msg <- Task.fromEither(decode[Json](text))
msgType = msg.hcursor.downField("type").success.flatMap(_.value.asString).getOrElse("")
_ <- IO.whenCase(msgType) {
case "connection_init" =>
IO.fromFuture(_ => queue.offer(TextMessage("""{"type":"connection_ack"}""")))
case "connection_terminate" =>
IO.effect(queue.complete())
case "start" =>
val payload = msg.hcursor.downField("payload")
val id = msg.hcursor.downField("id").success.flatMap(_.value.asString).getOrElse("")
Task.whenCase(payload.downField("query").success.flatMap(_.value.asString)) {
case Some(query) =>
val operationName = payload.downField("operationName").success.flatMap(_.value.asString)
(for {
result <- execute(interpreter, GraphQLRequest(query, operationName, None))
_ <- result.data match {
case ObjectValue((fieldName, StreamValue(stream)) :: Nil) =>
stream.foreach { item =>
sendMessage(queue, id, ObjectValue(List(fieldName -> item)), result.errors)
}.fork.flatMap(fiber => subscriptions.update(_.updated(id, fiber)))
case other =>
sendMessage(queue, id, other, result.errors) *> IO.fromFuture(
_ => queue.offer(TextMessage(s"""{"type":"complete","id":"$id"}"""))
)
}
} yield ()).catchAll(
error =>
IO.fromFuture(
_ =>
queue.offer(
TextMessage(
Json
.obj(
"id" -> Json.fromString(id),
"type" -> Json.fromString("complete"),
"payload" -> Json.fromString(error.toString)
)
.noSpaces
)
)
)
)
}
case "stop" =>
val id = msg.hcursor.downField("id").success.flatMap(_.value.asString).getOrElse("")
subscriptions
.modify(map => (map.get(id), map - id))
.flatMap(fiber => IO.whenCase(fiber) { case Some(fiber) => fiber.interrupt })
}
} yield ()
runtime.unsafeRun(io)
case _ => ()
}

val flow = Flow.fromSinkAndSource(sink, source).watchTermination() { (_, f) =>
f.onComplete(_ => runtime.unsafeRun(subscriptions.get.flatMap(m => IO.foreach(m.values)(_.interrupt).unit)))
}

complete(upgrade.handleMessages(flow, subprotocol = Some("graphql-ws")))
}
}
}
}
23 changes: 18 additions & 5 deletions examples/src/main/scala/caliban/akkahttp/ExampleApp.scala
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package caliban.akkahttp

import scala.language.postfixOps
import scala.io.StdIn
import akka.actor.ActorSystem
import akka.http.scaladsl.Http
@@ -8,9 +9,12 @@ import caliban.ExampleData.{ sampleCharacters, Character, CharacterArgs, Charact
import caliban.GraphQL.graphQL
import caliban.schema.Annotations.{ GQLDeprecated, GQLDescription }
import caliban.schema.GenericSchema
import caliban.{ AkkaHttpAdapter, ExampleService, RootResolver }
import caliban.wrappers.ApolloTracing.apolloTracing
import caliban.wrappers.Wrappers._
import caliban.{ AkkaHttpAdapter, ExampleService, GraphQL, RootResolver }
import zio.clock.Clock
import zio.console.Console
import zio.duration._
import zio.stream.ZStream
import zio.{ DefaultRuntime, URIO }

@@ -34,9 +38,7 @@ object ExampleApp extends App with GenericSchema[Console with Clock] {
case class Mutations(deleteCharacter: CharacterArgs => URIO[Console, Boolean])
case class Subscriptions(characterDeleted: ZStream[Console, Nothing, String])

val service = defaultRuntime.unsafeRun(ExampleService.make(sampleCharacters))

val interpreter =
def makeApi(service: ExampleService): GraphQL[Console with Clock] =
graphQL(
RootResolver(
Queries(
@@ -46,7 +48,16 @@ object ExampleApp extends App with GenericSchema[Console with Clock] {
Mutations(args => service.deleteCharacter(args.name)),
Subscriptions(service.deletedEvents)
)
).interpreter
) @@
maxFields(200) @@ // query analyzer that limit query fields
maxDepth(30) @@ // query analyzer that limit query depth
timeout(3 seconds) @@ // wrapper that fails slow queries
printSlowQueries(500 millis) @@ // wrapper that logs slow queries
apolloTracing // wrapper for https://github.com/apollographql/apollo-tracing

val service = defaultRuntime.unsafeRun(ExampleService.make(sampleCharacters))

val interpreter = makeApi(service).interpreter

/**
* curl -X POST \
@@ -60,6 +71,8 @@ object ExampleApp extends App with GenericSchema[Console with Clock] {
val route =
path("api" / "graphql") {
AkkaHttpAdapter.makeHttpService(interpreter)
} ~ path("ws" / "graphql") {
AkkaHttpAdapter.makeWebSocketService(interpreter)
} ~ path("graphiql") {
getFromResource("graphiql.html")
}
4 changes: 3 additions & 1 deletion http4s/src/main/scala/caliban/Http4sAdapter.scala
Original file line number Diff line number Diff line change
@@ -152,7 +152,9 @@ object Http4sAdapter {
}
case "stop" =>
val id = msg.hcursor.downField("id").success.flatMap(_.value.asString).getOrElse("")
subscriptions.get.flatMap(map => IO.whenCase(map.get(id)) { case Some(fiber) => fiber.interrupt })
subscriptions
.modify(map => (map.get(id), map - id))
.flatMap(fiber => IO.whenCase(fiber) { case Some(fiber) => fiber.interrupt })
}
} yield ()
}