diff --git a/src/messages/subscribe.ts b/src/messages/subscribe.ts index 922a8dd..27f34cf 100644 --- a/src/messages/subscribe.ts +++ b/src/messages/subscribe.ts @@ -1,5 +1,11 @@ import { SubscribeMessage, MessageType } from 'graphql-ws'; -import { validate, parse, execute, GraphQLError } from 'graphql'; +import { + validate, + parse, + execute, + GraphQLError, + ExecutionResult, +} from 'graphql'; import { buildExecutionContext, assertValidExecutionArguments, @@ -9,6 +15,7 @@ import { constructContext, deleteConnection, getResolverAndArgs, + isAsyncIterable, promisify, sendMessage, } from '../utils'; @@ -45,10 +52,10 @@ export const subscribe: MessageHandler = } const contextValue = await constructContext(c)({ connectionParams }); - + const query = parse(message.payload.query); const execContext = buildExecutionContext( c.schema, - parse(message.payload.query), + query, undefined, contextValue, message.payload.variables, @@ -70,25 +77,29 @@ export const subscribe: MessageHandler = } if (execContext.operation.operation !== 'subscription') { - const result = await execute( - c.schema, - parse(message.payload.query), - undefined, + const result = await execute({ + schema: c.schema, + document: query, contextValue, - message.payload.variables, - message.payload.operationName, - undefined - ); - - await sendMessage({ - ...event.requestContext, - message: { - type: MessageType.Next, - id: message.id, - payload: result, - }, + variableValues: message.payload.variables, + operationName: message.payload.operationName, }); + // Support for @defer and @stream directives + const parts = isAsyncIterable(result) + ? result + : [result]; + for await (let part of parts) { + await sendMessage({ + ...event.requestContext, + message: { + type: MessageType.Next, + id: message.id, + payload: part, + }, + }); + } + await sendMessage({ ...event.requestContext, message: { diff --git a/src/pubsub/publish.ts b/src/pubsub/publish.ts index 9d9c0c8..ad10d97 100644 --- a/src/pubsub/publish.ts +++ b/src/pubsub/publish.ts @@ -3,11 +3,11 @@ import { equals, ConditionExpression, } from '@aws/dynamodb-expressions'; -import { parse, execute } from 'graphql'; +import { parse, execute, ExecutionResult } from 'graphql'; import { MessageType } from 'graphql-ws'; -import { assign, Subscription } from '../model'; +import { Subscription } from '../model'; import { ServerClosure } from '../types'; -import { constructContext, sendMessage } from '../utils'; +import { constructContext, isAsyncIterable, sendMessage } from '../utils'; type PubSubEvent = { topic: string; @@ -17,24 +17,27 @@ type PubSubEvent = { export const publish = (c: ServerClosure) => async (event: PubSubEvent) => { const subscriptions = await getFilteredSubs(c)(event); const iters = subscriptions.map(async (sub) => { - const result = execute( - c.schema, - parse(sub.subscription.query), - event, - await constructContext(c)(sub), - sub.subscription.variables, - sub.subscription.operationName, - undefined - ); - - await sendMessage({ - ...sub.requestContext, - message: { - id: sub.subscriptionId, - type: MessageType.Next, - payload: await result, - }, + const result = execute({ + schema: c.schema, + document: parse(sub.subscription.query), + rootValue: event, + contextValue: await constructContext(c)(sub), + variableValues: sub.subscription.variables, + operationName: sub.subscription.operationName, }); + + // Support for @defer and @stream directives + const parts = isAsyncIterable(result) ? result : [result]; + for await (let part of parts) { + await sendMessage({ + ...sub.requestContext, + message: { + id: sub.subscriptionId, + type: MessageType.Next, + payload: part, + }, + }); + } }); return await Promise.all(iters); }; diff --git a/src/utils/index.ts b/src/utils/index.ts index 2346c94..fade44c 100644 --- a/src/utils/index.ts +++ b/src/utils/index.ts @@ -1,4 +1,5 @@ export * from './aws'; export * from './date'; +export * from './isAsyncIterable'; export * from './promise'; export * from './graphql'; diff --git a/src/utils/isAsyncIterable.ts b/src/utils/isAsyncIterable.ts new file mode 100644 index 0000000..65a12c4 --- /dev/null +++ b/src/utils/isAsyncIterable.ts @@ -0,0 +1,4 @@ +export const isAsyncIterable = (arg: any): arg is AsyncIterable => + arg !== null && + typeof arg == 'object' && + typeof arg[Symbol.asyncIterator] === 'function';