diff --git a/adapters/quick/src/main/scala/caliban/GraphiQLHandler.scala b/adapters/quick/src/main/scala/caliban/GraphiQLHandler.scala index c735c2fda9..0c08358ce2 100644 --- a/adapters/quick/src/main/scala/caliban/GraphiQLHandler.scala +++ b/adapters/quick/src/main/scala/caliban/GraphiQLHandler.scala @@ -13,7 +13,11 @@ object GraphiQLHandler { * @see [[https://github.com/graphql/graphiql/tree/main/examples/graphiql-cdn]] */ def handler(apiPath: String, graphiqlPath: String): RequestHandler[Any, Nothing] = - Handler.fromBody(Body.fromString(html(apiPath, graphiqlPath))) + Response( + Status.Ok, + Headers(Header.ContentType(MediaType.text.html).untyped), + Body.fromString(html(apiPath, graphiqlPath)) + ).toHandler def html(apiPath: String, uiPath: String): String = s""" diff --git a/adapters/quick/src/main/scala/caliban/QuickAdapter.scala b/adapters/quick/src/main/scala/caliban/QuickAdapter.scala index 6c8b5e8c0b..0ba81f9915 100644 --- a/adapters/quick/src/main/scala/caliban/QuickAdapter.scala +++ b/adapters/quick/src/main/scala/caliban/QuickAdapter.scala @@ -11,7 +11,8 @@ final class QuickAdapter[-R] private (requestHandler: QuickRequestHandler[R]) { */ val handlers: QuickHandlers[R] = QuickHandlers( api = Handler.fromFunctionZIO[Request](requestHandler.handleHttpRequest), - upload = Handler.fromFunctionZIO[Request](requestHandler.handleUploadRequest) + upload = Handler.fromFunctionZIO[Request](requestHandler.handleUploadRequest), + webSocket = Handler.fromFunctionZIO[Request](requestHandler.handleWebSocketRequest) ) @deprecated("Use `handlers` instead", "2.5.0") @@ -24,11 +25,13 @@ final class QuickAdapter[-R] private (requestHandler: QuickRequestHandler[R]) { * @param apiPath The path where the GraphQL API will be served. * @param graphiqlPath The path where the GraphiQL UI will be served. If None, GraphiQL will not be served. * @param uploadPath The path where files can be uploaded. If None, uploads will be disabled. + * @param webSocketPath The path where websocket requests will be set. If None, websocket-based subscriptions will be disabled. */ def toApp( apiPath: String, graphiqlPath: Option[String] = None, - uploadPath: Option[String] = None + uploadPath: Option[String] = None, + webSocketPath: Option[String] = None ): HttpApp[R] = { val apiRoutes = List( RoutePattern(Method.POST, apiPath) -> handlers.api, @@ -40,8 +43,10 @@ final class QuickAdapter[-R] private (requestHandler: QuickRequestHandler[R]) { val uploadRoute = uploadPath.toList.map { uPath => RoutePattern(Method.POST, uPath) -> handlers.upload } - - Routes.fromIterable(apiRoutes ::: graphiqlRoute ::: uploadRoute).toHttpApp + val wsRoute = webSocketPath.toList.map { wsPath => + RoutePattern(Method.ANY, wsPath) -> handlers.webSocket + } + Routes.fromIterable(apiRoutes ::: graphiqlRoute ::: uploadRoute ::: wsRoute).toHttpApp } /** @@ -52,15 +57,17 @@ final class QuickAdapter[-R] private (requestHandler: QuickRequestHandler[R]) { * @param apiPath The route to serve the API on, e.g., `/api/graphql` * @param graphiqlPath Optionally define a route to serve the GraphiQL UI on, e.g., `/graphiql` * @param uploadPath The route where files can be uploaded, e.g., /upload/graphql. If None, uploads will be disabled. + * @param webSocketPath The path where websocket requests will be set. If None, websocket-based subscriptions will be disabled. */ def runServer( port: Int, apiPath: String, graphiqlPath: Option[String] = None, - uploadPath: Option[String] = None + uploadPath: Option[String] = None, + webSocketPath: Option[String] = None )(implicit trace: Trace): RIO[R, Nothing] = Server - .serve[R](toApp(apiPath, graphiqlPath = graphiqlPath, uploadPath = uploadPath)) + .serve[R](toApp(apiPath, graphiqlPath = graphiqlPath, uploadPath = uploadPath, webSocketPath = webSocketPath)) .provideSomeLayer[R](Server.defaultWithPort(port)) def configure(config: ExecutionConfiguration)(implicit trace: Trace): QuickAdapter[R] = @@ -69,13 +76,16 @@ final class QuickAdapter[-R] private (requestHandler: QuickRequestHandler[R]) { def configure[R1](configurator: QuickAdapter.Configurator[R1])(implicit trace: Trace): QuickAdapter[R & R1] = new QuickAdapter(requestHandler.configure[R1](configurator)) + def configureWebSocket[R1](config: quick.WebSocketConfig[R1]): QuickAdapter[R & R1] = + new QuickAdapter(requestHandler.configureWebSocket(config)) + } object QuickAdapter { type Configurator[-R] = URIO[R & Scope, Unit] def apply[R](interpreter: GraphQLInterpreter[R, Any]): QuickAdapter[R] = - new QuickAdapter(new QuickRequestHandler(interpreter)) + new QuickAdapter(new QuickRequestHandler(interpreter, quick.WebSocketConfig.default)) def handlers[R](implicit tag: Tag[R], trace: Trace): URIO[QuickAdapter[R], QuickHandlers[R]] = ZIO.serviceWith(_.handlers) diff --git a/adapters/quick/src/main/scala/caliban/QuickHandlers.scala b/adapters/quick/src/main/scala/caliban/QuickHandlers.scala index 84eac679cd..66c26ca252 100644 --- a/adapters/quick/src/main/scala/caliban/QuickHandlers.scala +++ b/adapters/quick/src/main/scala/caliban/QuickHandlers.scala @@ -4,7 +4,8 @@ import zio.http.{ HandlerAspect, RequestHandler } final case class QuickHandlers[-R]( api: RequestHandler[R, Nothing], - upload: RequestHandler[R, Nothing] + upload: RequestHandler[R, Nothing], + webSocket: RequestHandler[R, Nothing] ) { /** @@ -13,7 +14,8 @@ final case class QuickHandlers[-R]( def @@[R1 <: R](aspect: HandlerAspect[R1, Unit]): QuickHandlers[R1] = QuickHandlers( api = (api @@ aspect).merge, - upload = (upload @@ aspect).merge + upload = (upload @@ aspect).merge, + webSocket = (webSocket @@ aspect).merge ) } diff --git a/adapters/quick/src/main/scala/caliban/QuickRequestHandler.scala b/adapters/quick/src/main/scala/caliban/QuickRequestHandler.scala index b5192a68fc..30e381c9b2 100644 --- a/adapters/quick/src/main/scala/caliban/QuickRequestHandler.scala +++ b/adapters/quick/src/main/scala/caliban/QuickRequestHandler.scala @@ -6,32 +6,42 @@ import caliban.ResponseValue.StreamValue import caliban.interop.jsoniter.ValueJsoniter import caliban.uploads.{ FileMeta, GraphQLUploadRequest, Uploads } import caliban.wrappers.Caching +import caliban.ws.Protocol import com.github.plokhotnyuk.jsoniter_scala.core._ import com.github.plokhotnyuk.jsoniter_scala.macros.JsonCodecMaker import zio._ +import zio.http.ChannelEvent.UserEvent.HandshakeComplete import zio.http.Header.ContentType import zio.http._ import zio.stacktracer.TracingImplicits.disableAutoTrace -import zio.stream.{ UStream, ZStream } +import zio.stream.{ UStream, ZPipeline, ZStream } import java.nio.charset.StandardCharsets.UTF_8 import scala.util.control.NonFatal -final private class QuickRequestHandler[-R](interpreter: GraphQLInterpreter[R, Any]) { +final private class QuickRequestHandler[R]( + interpreter: GraphQLInterpreter[R, Any], + wsConfig: quick.WebSocketConfig[R] +) { import QuickRequestHandler._ def configure(config: ExecutionConfiguration)(implicit trace: Trace): QuickRequestHandler[R] = new QuickRequestHandler[R]( - interpreter.wrapExecutionWith[R, Any](Configurator.setWith(config)(_)) + interpreter.wrapExecutionWith[R, Any](Configurator.setWith(config)(_)), + wsConfig ) def configure[R1](configurator: QuickAdapter.Configurator[R1])(implicit trace: Trace ): QuickRequestHandler[R & R1] = new QuickRequestHandler[R & R1]( - interpreter.wrapExecutionWith[R & R1, Any](exec => ZIO.scoped[R1 & R](configurator *> exec)) + interpreter.wrapExecutionWith[R & R1, Any](exec => ZIO.scoped[R1 & R](configurator *> exec)), + wsConfig ) + def configureWebSocket[R1](config: quick.WebSocketConfig[R1]): QuickRequestHandler[R & R1] = + new QuickRequestHandler[R & R1](interpreter, config) + def handleHttpRequest(request: Request)(implicit trace: Trace): URIO[R, Response] = transformHttpRequest(request) .flatMap(executeRequest(request.method, _)) @@ -45,6 +55,17 @@ final private class QuickRequestHandler[-R](interpreter: GraphQLInterpreter[R, A .provideSomeLayer[R](fileHandle) }.merge + def handleWebSocketRequest(request: Request)(implicit trace: Trace): URIO[R, Response] = + Response.fromSocketApp { + val protocol = request.headers.get(Header.SecWebSocketProtocol) match { + case Some(value) => Protocol.fromName(value.renderedValue) + case None => Protocol.Legacy + } + Handler + .webSocket(webSocketChannelListener(protocol)) + .withConfig(wsConfig.zHttpConfig.subProtocol(Some(protocol.name))) + } + private def transformHttpRequest(httpReq: Request)(implicit trace: Trace): IO[Response, GraphQLRequest] = { def decodeQueryParams(queryParams: QueryParams): Either[Response, GraphQLRequest] = { @@ -214,6 +235,28 @@ final private class QuickRequestHandler[-R](interpreter: GraphQLInterpreter[R, A req.headers .get(GraphQLRequest.`apollo-federation-include-trace`) .exists(_.equalsIgnoreCase(GraphQLRequest.ftv1)) + + private def webSocketChannelListener(protocol: Protocol)(ch: WebSocketChannel)(implicit trace: Trace): RIO[R, Unit] = + for { + queue <- Queue.unbounded[GraphQLWSInput] + pipe <- protocol.make(interpreter, wsConfig.keepAliveTime, wsConfig.hooks).map(ZPipeline.fromFunction(_)) + out = ZStream + .fromQueueWithShutdown(queue) + .via(pipe) + .interruptWhen(ch.awaitShutdown) + .map { + case Right(output) => WebSocketFrame.Text(writeToString(output)) + case Left(close) => WebSocketFrame.Close(close.code, Some(close.reason)) + } + _ <- ZIO.scoped(ch.receiveAll { + case ChannelEvent.UserEventTriggered(HandshakeComplete) => + out.runForeach(frame => ch.send(ChannelEvent.Read(frame))).forkScoped + case ChannelEvent.Read(WebSocketFrame.Text(text)) => + ZIO.suspend(queue.offer(readFromString[GraphQLWSInput](text))) + case _ => + ZIO.unit + }) + } yield () } object QuickRequestHandler { diff --git a/adapters/quick/src/main/scala/caliban/quick/WebSocketConfig.scala b/adapters/quick/src/main/scala/caliban/quick/WebSocketConfig.scala new file mode 100644 index 0000000000..5535f3cfc2 --- /dev/null +++ b/adapters/quick/src/main/scala/caliban/quick/WebSocketConfig.scala @@ -0,0 +1,24 @@ +package caliban.quick + +import caliban.ws.WebSocketHooks +import zio._ +import zio.http.{ WebSocketConfig => ZWebSocketConfig } + +case class WebSocketConfig[-R]( + keepAliveTime: Option[Duration], + hooks: WebSocketHooks[R, Any], + zHttpConfig: ZWebSocketConfig +) { + def withHooks[R1](newHooks: WebSocketHooks[R1, Any]): WebSocketConfig[R & R1] = + copy(hooks = hooks ++ newHooks) + + def withKeepAliveTime(time: Duration): WebSocketConfig[R] = + copy(keepAliveTime = Some(time)) + + def withZHttpConfig(newConfig: ZWebSocketConfig): WebSocketConfig[R] = + copy(zHttpConfig = newConfig) +} + +object WebSocketConfig { + def default: WebSocketConfig[Any] = WebSocketConfig(None, WebSocketHooks.empty, ZWebSocketConfig.default) +} diff --git a/adapters/quick/src/main/scala/caliban/quick/package.scala b/adapters/quick/src/main/scala/caliban/quick/package.scala index a8929e7472..1b923cf3ef 100644 --- a/adapters/quick/src/main/scala/caliban/quick/package.scala +++ b/adapters/quick/src/main/scala/caliban/quick/package.scala @@ -16,16 +16,26 @@ package object quick { * @param apiPath The route to serve the API on, e.g., `/api/graphql` * @param graphiqlPath Optionally define a route to serve the GraphiQL UI on, e.g., `/graphiql` * @param uploadPath Optionally define a route to serve file uploads on, e.g., `/api/upload` + * @param webSocketPath The path where websocket requests will be set. If None, websocket-based subscriptions will be disabled. */ def runServer( port: Int, apiPath: String, graphiqlPath: Option[String] = None, - uploadPath: Option[String] = None + uploadPath: Option[String] = None, + webSocketPath: Option[String] = None )(implicit trace: Trace ): RIO[R, Nothing] = - gql.interpreter.flatMap(QuickAdapter(_).runServer(port, apiPath, graphiqlPath, uploadPath)) + gql.interpreter.flatMap( + QuickAdapter(_).runServer( + port, + apiPath = apiPath, + graphiqlPath = graphiqlPath, + uploadPath = uploadPath, + webSocketPath = webSocketPath + ) + ) /** * Creates zio-http `HttpApp` from the GraphQL API @@ -37,9 +47,17 @@ package object quick { def toApp( apiPath: String, graphiqlPath: Option[String] = None, - uploadPath: Option[String] = None + uploadPath: Option[String] = None, + webSocketPath: Option[String] = None )(implicit trace: Trace): IO[CalibanError.ValidationError, HttpApp[R]] = - gql.interpreter.map(QuickAdapter(_).toApp(apiPath, graphiqlPath, uploadPath)) + gql.interpreter.map( + QuickAdapter(_).toApp( + apiPath = apiPath, + graphiqlPath = graphiqlPath, + uploadPath = uploadPath, + webSocketPath = webSocketPath + ) + ) /** * Creates a zio-http handler for the GraphQL API diff --git a/adapters/quick/src/test/scala/caliban/QuickAdapterSpec.scala b/adapters/quick/src/test/scala/caliban/QuickAdapterSpec.scala index 65c5413bfd..0e8925badf 100644 --- a/adapters/quick/src/test/scala/caliban/QuickAdapterSpec.scala +++ b/adapters/quick/src/test/scala/caliban/QuickAdapterSpec.scala @@ -24,7 +24,9 @@ object QuickAdapterSpec extends ZIOSpecDefault { private val apiLayer = envLayer >>> ZLayer.fromZIO { for { - app <- TestApi.api.toApp("/api/graphql", uploadPath = Some("/upload/graphql")).map(_ @@ auth) + app <- TestApi.api + .toApp("/api/graphql", uploadPath = Some("/upload/graphql"), webSocketPath = Some("/ws/graphql")) + .map(_ @@ auth) _ <- Server.serve(app).forkScoped _ <- Live.live(Clock.sleep(3 seconds)) service <- ZIO.service[TestService] @@ -35,7 +37,7 @@ object QuickAdapterSpec extends ZIOSpecDefault { val suite = TapirAdapterSpec.makeSuite( "QuickAdapterSpec", uri"http://localhost:8090/api/graphql", - wsUri = None, + wsUri = Some(uri"ws://localhost:8090/ws/graphql"), uploadUri = Some(uri"http://localhost:8090/upload/graphql") ) suite.provideShared( diff --git a/adapters/zio-http/src/main/scala/caliban/ZHttpAdapter.scala b/adapters/zio-http/src/main/scala/caliban/ZHttpAdapter.scala index 29cacb7a36..9a1650b684 100644 --- a/adapters/zio-http/src/main/scala/caliban/ZHttpAdapter.scala +++ b/adapters/zio-http/src/main/scala/caliban/ZHttpAdapter.scala @@ -1,12 +1,16 @@ package caliban -import caliban.interop.tapir.ws.Protocol import caliban.interop.tapir.{ HttpInterpreter, WebSocketInterpreter } +import caliban.ws.Protocol import sttp.capabilities.zio.ZioStreams import sttp.model.HeaderNames import sttp.tapir.server.ziohttp.{ ZioHttpInterpreter, ZioHttpServerOptions } import zio.http._ +@deprecated( + "The `caliban-zio-http` package is deprecated and scheduled to be removed in a future release. To use Caliban with zio-http, use the `caliban-quick` module instead", + "2.6.0" +) object ZHttpAdapter { @deprecated("Defining subprotocols in the server config is no longer required") diff --git a/adapters/zio-http/src/test/scala/caliban/ZHttpAdapterSpec.scala b/adapters/zio-http/src/test/scala/caliban/ZHttpAdapterSpec.scala index 1350aec791..72f3b4d1f3 100644 --- a/adapters/zio-http/src/test/scala/caliban/ZHttpAdapterSpec.scala +++ b/adapters/zio-http/src/test/scala/caliban/ZHttpAdapterSpec.scala @@ -15,8 +15,10 @@ import zio._ import zio.http._ import zio.test.{ Live, ZIOSpecDefault } +import scala.annotation.nowarn import scala.language.postfixOps +@nowarn object ZHttpAdapterSpec extends ZIOSpecDefault { import sttp.tapir.json.zio._ diff --git a/interop/tapir/src/main/scala/caliban/interop/tapir/ws/Protocol.scala b/core/src/main/scala/caliban/ws/Protocol.scala similarity index 95% rename from interop/tapir/src/main/scala/caliban/interop/tapir/ws/Protocol.scala rename to core/src/main/scala/caliban/ws/Protocol.scala index 2a5345b795..b02eb219da 100644 --- a/interop/tapir/src/main/scala/caliban/interop/tapir/ws/Protocol.scala +++ b/core/src/main/scala/caliban/ws/Protocol.scala @@ -1,10 +1,8 @@ -package caliban.interop.tapir.ws +package caliban.ws import caliban.ResponseValue.{ ObjectValue, StreamValue } import caliban.Value.StringValue import caliban._ -import caliban.interop.tapir.TapirAdapter.CalibanPipe -import caliban.interop.tapir.WebSocketHooks import zio.stm.TMap import zio.stream.{ UStream, ZStream } import zio.{ Duration, Promise, Queue, Random, Ref, Schedule, UIO, URIO, ZIO } @@ -22,10 +20,9 @@ sealed trait Protocol { object Protocol { - def fromName(name: String): Protocol = name match { - case GraphQLWS.name => GraphQLWS - case _ => Legacy - } + def fromName(name: String): Protocol = + if (name.equalsIgnoreCase(GraphQLWS.name)) GraphQLWS + else Legacy object GraphQLWS extends Protocol { object Ops { @@ -39,9 +36,9 @@ object Protocol { final val ConnectionAck = "connection_ack" } - val name = "graphql-transport-ws" + final val name = "graphql-transport-ws" - val handler = new ResponseHandler { + private val handler: ResponseHandler = new ResponseHandler { override def toResponse[E](id: String, r: GraphQLResponse[E]): GraphQLWSOutput = GraphQLWSOutput(Ops.Next, Some(id), Some(r.toResponseValue)) @@ -126,8 +123,7 @@ object Protocol { ZIO.ifZIO(subscriptions.isTracking(id))( output.offer(Left(GraphQLWSClose(4409, s"Subscriber for $id already exists"))).unit, webSocketHooks.onMessage - .map(_.transform(stream)) - .getOrElse(stream) + .fold(stream)(stream.via(_)) .map(Right(_)) .runForeachChunk(output.offerAll) .catchAll(e => output.offer(Right(handler.error(Some(id), e)))) @@ -190,9 +186,9 @@ object Protocol { final val Data = "data" } - val name = "graphql-ws" + final val name = "graphql-ws" - val handler: ResponseHandler = new ResponseHandler { + private val handler: ResponseHandler = new ResponseHandler { override def toResponse[E](id: String, r: GraphQLResponse[E]): GraphQLWSOutput = GraphQLWSOutput(Ops.Data, Some(id), Some(r.toResponseValue)) @@ -258,8 +254,7 @@ object Protocol { val stream = handler.generateGraphQLResponse(req, id.getOrElse(""), interpreter, subscriptions) webSocketHooks.onMessage - .map(_.transform(stream)) - .getOrElse(stream) + .fold(stream)(stream.via(_)) .runForeachChunk(o => output.offerAll(o.map(Right(_)))) .catchAll(e => output.offer(Right(handler.error(id, e)))) .fork @@ -298,7 +293,7 @@ object Protocol { GraphQLWSOutput(Ops.ConnectionAck, None, payload) } - private[ws] trait ResponseHandler { + private trait ResponseHandler { self => def toResponse[E](id: String, fieldName: String, r: ResponseValue, errors: List[E]): GraphQLWSOutput = toResponse(id, GraphQLResponse(ObjectValue(List(fieldName -> r)), errors)) @@ -339,7 +334,7 @@ object Protocol { } } - private[ws] class SubscriptionManager private (private val tracked: TMap[String, Promise[Any, Unit]]) { + private class SubscriptionManager private (private val tracked: TMap[String, Promise[Any, Unit]]) { def track(id: String): UStream[Promise[Any, Unit]] = ZStream.fromZIO(Promise.make[Any, Unit].tap(tracked.put(id, _).commit)) diff --git a/core/src/main/scala/caliban/ws/WebSocketHooks.scala b/core/src/main/scala/caliban/ws/WebSocketHooks.scala new file mode 100644 index 0000000000..2b02531891 --- /dev/null +++ b/core/src/main/scala/caliban/ws/WebSocketHooks.scala @@ -0,0 +1,117 @@ +package caliban.ws + +import caliban.{ GraphQLWSOutput, InputValue, ResponseValue } +import zio.ZIO +import zio.stream.ZPipeline + +trait WebSocketHooks[-R, +E] { self => + def beforeInit: Option[InputValue => ZIO[R, E, Any]] = None + def afterInit: Option[ZIO[R, E, Any]] = None + def onMessage: Option[ZPipeline[R, E, GraphQLWSOutput, GraphQLWSOutput]] = None + def onPong: Option[InputValue => ZIO[R, E, Any]] = None + def onPing: Option[Option[InputValue] => ZIO[R, E, Option[ResponseValue]]] = None + def onAck: Option[ZIO[R, E, ResponseValue]] = None + + def ++[R2 <: R, E2 >: E](other: WebSocketHooks[R2, E2]): WebSocketHooks[R2, E2] = + new WebSocketHooks[R2, E2] { + override def beforeInit: Option[InputValue => ZIO[R2, E2, Any]] = (self.beforeInit, other.beforeInit) match { + case (None, Some(f)) => Some(f) + case (Some(f), None) => Some(f) + case (Some(f1), Some(f2)) => Some((x: InputValue) => f1(x) *> f2(x)) + case _ => None + } + + override def afterInit: Option[ZIO[R2, E2, Any]] = (self.afterInit, other.afterInit) match { + case (None, Some(f)) => Some(f) + case (Some(f), None) => Some(f) + case (Some(f1), Some(f2)) => Some(f1 &> f2) + case _ => None + } + + override def onMessage: Option[ZPipeline[R2, E2, GraphQLWSOutput, GraphQLWSOutput]] = + (self.onMessage, other.onMessage) match { + case (None, Some(f)) => Some(f) + case (Some(f), None) => Some(f) + case (Some(f1), Some(f2)) => Some(f1.andThen(f2)) + case _ => None + } + + override def onPong: Option[InputValue => ZIO[R2, E2, Any]] = (self.onPong, other.onPong) match { + case (None, Some(f)) => Some(f) + case (Some(f), None) => Some(f) + case (Some(f1), Some(f2)) => Some((x: InputValue) => f1(x) &> f2(x)) + case _ => None + } + + override def onPing: Option[Option[InputValue] => ZIO[R2, E2, Option[ResponseValue]]] = + (self.onPing, other.onPing) match { + case (None, Some(f)) => Some(f) + case (Some(f), None) => Some(f) + case (Some(f1), Some(f2)) => + Some { (x: Option[InputValue]) => + f1(x).zipWithPar(f2(x)) { + case (a @ Some(_), None) => a + case (None, b @ Some(_)) => b + case (Some(a), Some(b)) => Some(a.deepMerge(b)) + case _ => None + } + } + case _ => None + } + + override def onAck: Option[ZIO[R2, E2, ResponseValue]] = (self.onAck, other.onAck) match { + case (None, Some(f)) => Some(f) + case (Some(f), None) => Some(f) + case (Some(f1), Some(f2)) => Some((f1 zipWithPar f2)(_ deepMerge _)) + case _ => None + } + } +} + +object WebSocketHooks { + def empty[R, E]: WebSocketHooks[R, E] = new WebSocketHooks[R, E] {} + + /** + * Specifies a callback that will be run before an incoming subscription + * request is accepted. Useful for e.g authorizing the incoming subscription + * before accepting it. + */ + def init[R, E](f: InputValue => ZIO[R, E, Any]): WebSocketHooks[R, E] = + new WebSocketHooks[R, E] { + override def beforeInit: Option[InputValue => ZIO[R, E, Any]] = Some(f) + } + + /** + * Specifies a callback that will be run after an incoming subscription + * request has been accepted. Useful for e.g terminating a subscription + * after some time, such as authorization expiring. + */ + def afterInit[R, E](f: ZIO[R, E, Any]): WebSocketHooks[R, E] = + new WebSocketHooks[R, E] { + override def afterInit: Option[ZIO[R, E, Any]] = Some(f) + } + + /** + * Specifies a ZPipeline that will be applied to the resulting `ZStream` + * for every active subscription. Useful to e.g modify the environment + * to inject session information into the `ZStream` handling the + * subscription. + */ + def message[R, E](f: ZPipeline[R, E, GraphQLWSOutput, GraphQLWSOutput]): WebSocketHooks[R, E] = + new WebSocketHooks[R, E] { + override def onMessage: Option[ZPipeline[R, E, GraphQLWSOutput, GraphQLWSOutput]] = Some(f) + } + + /** + * Specifies a callback that will be run when ever a pong message is received. + */ + def pong[R, E](f: InputValue => ZIO[R, E, Any]): WebSocketHooks[R, E] = + new WebSocketHooks[R, E] { + override def onPong: Option[InputValue => ZIO[R, E, Any]] = Some(f) + } + + def ack[R, E](f: ZIO[R, E, ResponseValue]): WebSocketHooks[R, E] = + new WebSocketHooks[R, E] { + override def onAck: Option[ZIO[R, E, ResponseValue]] = Some(f) + } +} diff --git a/core/src/main/scala/caliban/ws/package.scala b/core/src/main/scala/caliban/ws/package.scala new file mode 100644 index 0000000000..7ac89657ef --- /dev/null +++ b/core/src/main/scala/caliban/ws/package.scala @@ -0,0 +1,8 @@ +package caliban + +import zio.stream.Stream + +package object ws { + type Pipe[A, B] = Stream[Throwable, A] => Stream[Throwable, B] + type CalibanPipe = Pipe[GraphQLWSInput, Either[GraphQLWSClose, GraphQLWSOutput]] +} diff --git a/examples/src/main/scala/example/quick/ExampleApp.scala b/examples/src/main/scala/example/quick/ExampleApp.scala index 7f258398ac..1a550169b5 100644 --- a/examples/src/main/scala/example/quick/ExampleApp.scala +++ b/examples/src/main/scala/example/quick/ExampleApp.scala @@ -1,10 +1,10 @@ package example.quick import caliban._ +import caliban.quick._ import example.ExampleData._ import example.{ ExampleApi, ExampleService } import zio._ -import caliban.quick._ object ExampleApp extends ZIOAppDefault { @@ -20,7 +20,8 @@ object ExampleApp extends ZIOAppDefault { _.runServer( port = 8090, apiPath = "/api/graphql", - graphiqlPath = Some("/graphiql") + graphiqlPath = Some("/graphiql"), + webSocketPath = Some("/ws/graphql") ) } .provide( diff --git a/examples/src/main/scala/example/stitching/ExampleApp.scala b/examples/src/main/scala/example/stitching/ExampleApp.scala index 35c774dedd..e5eaee71bc 100644 --- a/examples/src/main/scala/example/stitching/ExampleApp.scala +++ b/examples/src/main/scala/example/stitching/ExampleApp.scala @@ -1,12 +1,12 @@ package example.stitching import caliban._ -import caliban.interop.tapir.{ HttpInterpreter, WebSocketInterpreter } -import caliban.schema._ -import caliban.schema.Schema.auto._ +import caliban.quick._ import caliban.schema.ArgBuilder.auto._ -import caliban.tools.{ Options, RemoteSchema, SchemaLoader } +import caliban.schema.Schema.auto._ +import caliban.schema._ import caliban.tools.stitching.{ HttpRequest, RemoteResolver, RemoteSchemaResolver, ResolveRequest } +import caliban.tools.{ Options, RemoteSchema, SchemaLoader } import sttp.capabilities.WebSockets import sttp.capabilities.zio.ZioStreams import sttp.client3.SttpBackend @@ -98,35 +98,17 @@ object Configuration { private def read(key: String): Task[String] = ZIO.attempt(sys.env(key)) } -import zio.stream._ -import zio.http._ -import caliban.ZHttpAdapter - object ExampleApp extends ZIOAppDefault { - import sttp.tapir.json.circe._ - - private val graphiql = Handler.fromResource("graphiql.html").sandbox - def run = - (for { - api <- StitchingExample.api - interpreter <- api.interpreter - _ <- - Server - .serve( - Routes( - Method.ANY / "api" / "graphql" -> - ZHttpAdapter.makeHttpService(HttpInterpreter(interpreter)), - Method.ANY / "ws" / "graphql" -> - ZHttpAdapter.makeWebSocketService(WebSocketInterpreter(interpreter)), - Method.ANY / "graphiql" -> - graphiql - ).toHttpApp - ) - } yield ()) - .provide( - HttpClientZioBackend.layer(), - Configuration.fromEnvironment, - Server.default + StitchingExample.api.flatMap { + _.runServer( + port = 8080, + apiPath = "/api/graphql", + graphiqlPath = Some("/graphiql"), + webSocketPath = Some("/ws/graphql") ) + }.provide( + HttpClientZioBackend.layer(), + Configuration.fromEnvironment + ) } diff --git a/examples/src/main/scala/example/ziohttp/AuthExampleApp.scala b/examples/src/main/scala/example/ziohttp/AuthExampleApp.scala index 8466e9f099..95a05b6e56 100644 --- a/examples/src/main/scala/example/ziohttp/AuthExampleApp.scala +++ b/examples/src/main/scala/example/ziohttp/AuthExampleApp.scala @@ -2,8 +2,9 @@ package example.ziohttp import caliban.Value.StringValue import caliban._ -import caliban.interop.tapir.{ HttpInterpreter, WebSocketHooks, WebSocketInterpreter } +import caliban.interop.tapir.{ HttpInterpreter, WebSocketInterpreter } import caliban.schema.GenericSchema +import caliban.ws.WebSocketHooks import example.ExampleData._ import example.{ ExampleApi, ExampleService } import sttp.tapir.json.circe._ @@ -11,6 +12,8 @@ import zio._ import zio.http._ import zio.stream._ +import scala.annotation.nowarn + case object Unauthorized extends RuntimeException("Unauthorized") trait Auth { @@ -20,6 +23,7 @@ trait Auth { def setUser(name: Option[String]): UIO[Unit] } +@nowarn object Auth { val http: ULayer[Auth] = ZLayer.scoped { @@ -81,6 +85,7 @@ object Authed extends GenericSchema[Auth] { val api = graphQL(RootResolver(Queries(), None, Subscriptions())) } +@nowarn object AuthExampleApp extends ZIOAppDefault { private val graphiql = Handler.fromResource("graphiql.html").sandbox diff --git a/examples/src/main/scala/example/ziohttp/ExampleApp.scala b/examples/src/main/scala/example/ziohttp/ExampleApp.scala index 3497b3a9e3..a6fee68349 100644 --- a/examples/src/main/scala/example/ziohttp/ExampleApp.scala +++ b/examples/src/main/scala/example/ziohttp/ExampleApp.scala @@ -7,6 +7,9 @@ import caliban.interop.tapir.{ HttpInterpreter, WebSocketInterpreter } import zio._ import zio.http._ +import scala.annotation.nowarn + +@nowarn object ExampleApp extends ZIOAppDefault { import sttp.tapir.json.circe._ diff --git a/interop/tapir/src/main/scala/caliban/interop/tapir/TapirAdapter.scala b/interop/tapir/src/main/scala/caliban/interop/tapir/TapirAdapter.scala index edaffca387..4afc9fb6a9 100644 --- a/interop/tapir/src/main/scala/caliban/interop/tapir/TapirAdapter.scala +++ b/interop/tapir/src/main/scala/caliban/interop/tapir/TapirAdapter.scala @@ -1,13 +1,12 @@ package caliban.interop.tapir +import caliban.ResponseValue.StreamValue import caliban._ -import caliban.ResponseValue.{ ObjectValue, StreamValue } import caliban.wrappers.Caching import sttp.capabilities.zio.ZioStreams -import sttp.capabilities.zio.ZioStreams.Pipe import sttp.capabilities.{ Streams, WebSockets } -import sttp.model.{ headers => _, _ } import sttp.model.sse.ServerSentEvent +import sttp.model.{ headers => _, _ } import sttp.monad.MonadError import sttp.tapir.Codec.JsonCodec import sttp.tapir.model.ServerRequest @@ -22,7 +21,7 @@ import scala.concurrent.Future object TapirAdapter { - type CalibanPipe = Pipe[GraphQLWSInput, Either[GraphQLWSClose, GraphQLWSOutput]] + type CalibanPipe = caliban.ws.CalibanPipe type UploadRequest = (Seq[Part[Array[Byte]]], ServerRequest) type ZioWebSockets = ZioStreams with WebSockets @@ -159,31 +158,10 @@ object TapirAdapter { private val deferMultipartMediaType: MediaType = MediaType.MultipartMixed.copy(otherParameters = HttpUtils.DeferMultipart.DeferHeaderParams) - @deprecated("Kept for binary compatibility purposes. To be removed in 2.5.0", "2.4.3") - private object DeferMultipart { - private val Newline = "\r\n" - private val ContentType = "Content-Type: application/json; charset=utf-8" - private val SubHeader = s"$Newline$ContentType$Newline$Newline" - private val Boundary = "---" - private val BoundaryHeader = "-" - private val DeferSpec = "20220824" - - val InnerBoundary = s"$Newline$Boundary$SubHeader" - val EndBoundary = s"$Newline-----$Newline" - - private val DeferHeaderParams: Map[String, String] = Map("boundary" -> BoundaryHeader, "deferSpec" -> DeferSpec) - - val mediaType: MediaType = MediaType.MultipartMixed.copy(otherParameters = DeferHeaderParams) - } - private object GraphqlResponseJson extends CodecFormat { override val mediaType: MediaType = MediaType("application", "graphql-response+json") } - private object GraphqlServerSentEvent { - val mediaType: MediaType = MediaType.TextEventStream - } - private def encodeMultipartMixedResponse[E, BS]( resp: GraphQLResponse[E], stream: ZStream[Any, Throwable, ResponseValue] diff --git a/interop/tapir/src/main/scala/caliban/interop/tapir/WebSocketHooks.scala b/interop/tapir/src/main/scala/caliban/interop/tapir/WebSocketHooks.scala index 8e6c780229..b8f37ea60f 100644 --- a/interop/tapir/src/main/scala/caliban/interop/tapir/WebSocketHooks.scala +++ b/interop/tapir/src/main/scala/caliban/interop/tapir/WebSocketHooks.scala @@ -2,80 +2,20 @@ package caliban.interop.tapir import caliban.{ GraphQLWSOutput, InputValue, ResponseValue } import zio.ZIO -import zio.stream.ZStream +import zio.stream.{ ZPipeline, ZStream } +@deprecated( + "WebSocketHooks.onMessage now uses a ZPipeline instead. To convert your existing logic into a ZPipeline, use `ZPipeline.fromFunction`", + "2.6.0" +) trait StreamTransformer[-R, +E] { def transform[R1 <: R, E1 >: E](stream: ZStream[R1, E1, GraphQLWSOutput]): ZStream[R1, E1, GraphQLWSOutput] } -trait WebSocketHooks[-R, +E] { self => - def beforeInit: Option[InputValue => ZIO[R, E, Any]] = None - def afterInit: Option[ZIO[R, E, Any]] = None - def onMessage: Option[StreamTransformer[R, E]] = None - def onPong: Option[InputValue => ZIO[R, E, Any]] = None - def onPing: Option[Option[InputValue] => ZIO[R, E, Option[ResponseValue]]] = None - def onAck: Option[ZIO[R, E, ResponseValue]] = None - - def ++[R2 <: R, E2 >: E](other: WebSocketHooks[R2, E2]): WebSocketHooks[R2, E2] = - new WebSocketHooks[R2, E2] { - override def beforeInit: Option[InputValue => ZIO[R2, E2, Any]] = (self.beforeInit, other.beforeInit) match { - case (None, Some(f)) => Some(f) - case (Some(f), None) => Some(f) - case (Some(f1), Some(f2)) => Some((x: InputValue) => f1(x) *> f2(x)) - case _ => None - } - - override def afterInit: Option[ZIO[R2, E2, Any]] = (self.afterInit, other.afterInit) match { - case (None, Some(f)) => Some(f) - case (Some(f), None) => Some(f) - case (Some(f1), Some(f2)) => Some(f1 &> f2) - case _ => None - } - - override def onMessage: Option[StreamTransformer[R2, E2]] = - (self.onMessage, other.onMessage) match { - case (None, Some(f)) => Some(f) - case (Some(f), None) => Some(f) - case (Some(f1), Some(f2)) => - Some(new StreamTransformer[R2, E2] { - def transform[R1 <: R2, E1 >: E2](s: ZStream[R1, E1, GraphQLWSOutput]): ZStream[R1, E1, GraphQLWSOutput] = - f2.transform(f1.transform(s)) - }) - case _ => None - } - - override def onPong: Option[InputValue => ZIO[R2, E2, Any]] = (self.onPong, other.onPong) match { - case (None, Some(f)) => Some(f) - case (Some(f), None) => Some(f) - case (Some(f1), Some(f2)) => Some((x: InputValue) => f1(x) &> f2(x)) - case _ => None - } - - override def onPing: Option[Option[InputValue] => ZIO[R2, E2, Option[ResponseValue]]] = - (self.onPing, other.onPing) match { - case (None, Some(f)) => Some(f) - case (Some(f), None) => Some(f) - case (Some(f1), Some(f2)) => - Some { (x: Option[InputValue]) => - f1(x).zipWithPar(f2(x)) { - case (a @ Some(_), None) => a - case (None, b @ Some(_)) => b - case (Some(a), Some(b)) => Some(a.deepMerge(b)) - case _ => None - } - } - case _ => None - } - - override def onAck: Option[ZIO[R2, E2, ResponseValue]] = (self.onAck, other.onAck) match { - case (None, Some(f)) => Some(f) - case (Some(f), None) => Some(f) - case (Some(f1), Some(f2)) => Some((f1 zipWithPar f2)(_ deepMerge _)) - case _ => None - } - } -} +@deprecated("Use caliban.ws.WebSocketHooks instead", "2.6.0") +trait WebSocketHooks[-R, +E] extends caliban.ws.WebSocketHooks[R, E] +@deprecated("Use caliban.ws.WebSocketHooks instead", "2.6.0") object WebSocketHooks { def empty[R, E]: WebSocketHooks[R, E] = new WebSocketHooks[R, E] {} @@ -107,7 +47,19 @@ object WebSocketHooks { */ def message[R, E](f: StreamTransformer[R, E]): WebSocketHooks[R, E] = new WebSocketHooks[R, E] { - override def onMessage: Option[StreamTransformer[R, E]] = Some(f) + override def onMessage: Option[ZPipeline[R, E, GraphQLWSOutput, GraphQLWSOutput]] = + Some(ZPipeline.fromFunction(f.transform)) + } + + /** + * Specifies a ZPipeline that will be applied on the resulting `ZStream` + * for every active subscription. Useful to e.g modify the environment + * to inject session information into the `ZStream` handling the + * subscription. + */ + def message[R, E](f: ZPipeline[R, E, GraphQLWSOutput, GraphQLWSOutput]): WebSocketHooks[R, E] = + new WebSocketHooks[R, E] { + override def onMessage: Option[ZPipeline[R, E, GraphQLWSOutput, GraphQLWSOutput]] = Some(f) } /** diff --git a/interop/tapir/src/main/scala/caliban/interop/tapir/WebSocketInterpreter.scala b/interop/tapir/src/main/scala/caliban/interop/tapir/WebSocketInterpreter.scala index 5f2918de4f..90be964f8b 100644 --- a/interop/tapir/src/main/scala/caliban/interop/tapir/WebSocketInterpreter.scala +++ b/interop/tapir/src/main/scala/caliban/interop/tapir/WebSocketInterpreter.scala @@ -2,7 +2,7 @@ package caliban.interop.tapir import caliban._ import caliban.interop.tapir.TapirAdapter._ -import caliban.interop.tapir.ws.Protocol +import caliban.ws.Protocol import sttp.capabilities.zio.ZioStreams import sttp.model.{ headers => _ } import sttp.tapir.Codec.JsonCodec @@ -36,7 +36,7 @@ object WebSocketInterpreter { private case class Base[R, E]( interpreter: GraphQLInterpreter[R, E], keepAliveTime: Option[Duration], - webSocketHooks: WebSocketHooks[R, E] + webSocketHooks: ws.WebSocketHooks[R, E] )(implicit inputCodec: JsonCodec[GraphQLWSInput], outputCodec: JsonCodec[GraphQLWSOutput] @@ -77,7 +77,7 @@ object WebSocketInterpreter { def apply[R, E]( interpreter: GraphQLInterpreter[R, E], keepAliveTime: Option[Duration] = None, - webSocketHooks: WebSocketHooks[R, E] = WebSocketHooks.empty[R, E] + webSocketHooks: ws.WebSocketHooks[R, E] = ws.WebSocketHooks.empty[R, E] )(implicit inputCodec: JsonCodec[GraphQLWSInput], outputCodec: JsonCodec[GraphQLWSOutput] diff --git a/interop/tapir/src/test/scala/caliban/interop/tapir/TapirAdapterSpec.scala b/interop/tapir/src/test/scala/caliban/interop/tapir/TapirAdapterSpec.scala index d78f28007b..1332e08c35 100644 --- a/interop/tapir/src/test/scala/caliban/interop/tapir/TapirAdapterSpec.scala +++ b/interop/tapir/src/test/scala/caliban/interop/tapir/TapirAdapterSpec.scala @@ -7,8 +7,8 @@ import sttp.capabilities.zio.ZioStreams import sttp.capabilities.{ Effect, WebSockets } import sttp.client3.asynchttpclient.zio.AsyncHttpClientZioBackend import sttp.client3.httpclient.zio.SttpClient -import sttp.client3.{ BasicRequestBody, DeserializationException, HttpError, ResponseException, SttpBackend } import sttp.client3.impl.zio.ZioServerSentEvents +import sttp.client3.{ BasicRequestBody, DeserializationException, HttpError, ResponseException, SttpBackend } import sttp.model._ import sttp.model.sse.ServerSentEvent import sttp.tapir.Codec.JsonCodec @@ -340,7 +340,7 @@ object TapirAdapterSpec { runWS.map(runWS => suite("test ws endpoint")( test("legacy ws") { - import caliban.interop.tapir.ws.Protocol.Legacy.Ops + import caliban.ws.Protocol.Legacy.Ops val io = for { res <- ZIO.serviceWithZIO[SttpBackend[Task, ZioStreams with WebSockets]]( @@ -395,7 +395,7 @@ object TapirAdapterSpec { } } @@ TestAspect.timeout(60.seconds), test("graphql-ws") { - import caliban.interop.tapir.ws.Protocol.GraphQLWS.Ops + import caliban.ws.Protocol.GraphQLWS.Ops val io = for { res <- ZIO.serviceWithZIO[SttpBackend[Task, ZioStreams with WebSockets]](