Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for @defer directive #213

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions graphql/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
GraphQLSkipDirective,
GraphQLIncludeDirective,
GraphQLDeprecatedDirective,
GraphQLDeferDirective,
# Constant Deprecation Reason
DEFAULT_DEPRECATION_REASON,
# GraphQL Types for introspection.
Expand Down Expand Up @@ -198,6 +199,7 @@
"GraphQLSkipDirective",
"GraphQLIncludeDirective",
"GraphQLDeprecatedDirective",
"GraphQLDeferDirective",
"DEFAULT_DEPRECATION_REASON",
"TypeKind",
"DirectiveLocation",
Expand Down
84 changes: 64 additions & 20 deletions graphql/execution/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
GraphQLScalarType,
GraphQLSchema,
GraphQLUnionType,
GraphQLDeferDirective,
)
from .base import (
ExecutionContext,
Expand Down Expand Up @@ -63,7 +64,8 @@ def execute(
executor=None, # type: Any
return_promise=False, # type: bool
middleware=None, # type: Optional[Any]
allow_subscriptions=False, # type: bool
allow_subscriptions=False, # type: bool,
deferred_results = None, #type: Optional[List[Tuple[String, Promise[ExecutionResult]]]]
**options # type: Any
):
# type: (...) -> Union[ExecutionResult, Promise[ExecutionResult]]
Expand Down Expand Up @@ -117,11 +119,16 @@ def execute(
executor,
middleware,
allow_subscriptions,
)
)

if deferred_results is not None:
deferred = []
else:
deferred = None

def promise_executor(v):
# type: (Optional[Any]) -> Union[Dict, Promise[Dict], Observable]
return execute_operation(exe_context, exe_context.operation, root)
return execute_operation(exe_context, exe_context.operation, root, deferred)

def on_rejected(error):
# type: (Exception) -> None
Expand All @@ -142,6 +149,19 @@ def on_resolve(data):
Promise.resolve(None).then(promise_executor).catch(on_rejected).then(on_resolve)
)

def on_deferred_resolve(data, errors):
if len(errors) == 0:
return ExecutionResult(data=data)
return ExecutionResult(data=data, errors=errors)

if deferred_results is not None:
for path, deferred_promise in deferred:
errors = []
deferred_results.append(
(path, deferred_promise
.catch(errors.append)
.then(functools.partial(on_deferred_resolve, errors=errors))))

if not return_promise:
exe_context.executor.wait_until_finished()
return promise.get()
Expand All @@ -157,6 +177,7 @@ def execute_operation(
exe_context, # type: ExecutionContext
operation, # type: OperationDefinition
root_value, # type: Any
deferred, #type: Optional[List[Promise]]
):
# type: (...) -> Union[Dict, Promise[Dict]]
type = get_operation_root_type(exe_context.schema, operation)
Expand All @@ -176,7 +197,14 @@ def execute_operation(
)
return subscribe_fields(exe_context, type, root_value, fields)

return execute_fields(exe_context, type, root_value, fields, [], None)
if deferred is None:
deferred = []

result = execute_fields(exe_context, type, root_value, fields, [], None, deferred)
# if len(deferred) > 0:
# return Promise.all((result, deferred))
# else:
return result


def execute_fields_serially(
Expand All @@ -197,6 +225,7 @@ def execute_field_callback(results, response_name):
field_asts,
None,
path + [response_name],
[]
)
if result is Undefined:
return results
Expand Down Expand Up @@ -231,6 +260,7 @@ def execute_fields(
fields, # type: DefaultOrderedDict
path, # type: List[Union[int, str]]
info, # type: Optional[ResolveInfo]
deferred, #type: Optional[List[Promise]]
):
# type: (...) -> Union[Dict, Promise[Dict]]
contains_promise = False
Expand All @@ -245,13 +275,20 @@ def execute_fields(
field_asts,
info,
path + [response_name],
deferred
)
if result is Undefined:
continue

final_results[response_name] = result
if is_thenable(result):
contains_promise = True
if deferred is not None and any(
d.name.value == GraphQLDeferDirective.name for a in field_asts for d in a.directives):
final_results[response_name] = None
deferred.append((path + [response_name], Promise.resolve(result)))
else:
if is_thenable(result):
contains_promise = True

final_results[response_name] = result

if not contains_promise:
return final_results
Expand Down Expand Up @@ -316,6 +353,7 @@ def resolve_field(
field_asts, # type: List[Field]
parent_info, # type: Optional[ResolveInfo]
field_path, # type: List[Union[int, str]]
deferred, #type: Optional[List[Promise]]
):
# type: (...) -> Any
field_ast = field_asts[0]
Expand Down Expand Up @@ -360,7 +398,7 @@ def resolve_field(
result = resolve_or_error(resolve_fn_middleware, source, info, args, executor)

return complete_value_catching_error(
exe_context, return_type, field_asts, info, field_path, result
exe_context, return_type, field_asts, info, field_path, result, deferred
)


Expand Down Expand Up @@ -462,18 +500,19 @@ def complete_value_catching_error(
info, # type: ResolveInfo
path, # type: List[Union[int, str]]
result, # type: Any
deferred, #type: Optional[List[Promise]]
):
# type: (...) -> Any
# If the field type is non-nullable, then it is resolved without any
# protection from errors.
if isinstance(return_type, GraphQLNonNull):
return complete_value(exe_context, return_type, field_asts, info, path, result)
return complete_value(exe_context, return_type, field_asts, info, path, result, deferred)

# Otherwise, error protection is applied, logging the error and
# resolving a null value for this field if one is encountered.
try:
completed = complete_value(
exe_context, return_type, field_asts, info, path, result
exe_context, return_type, field_asts, info, path, result, deferred
)
if is_thenable(completed):

Expand All @@ -499,6 +538,7 @@ def complete_value(
info, # type: ResolveInfo
path, # type: List[Union[int, str]]
result, # type: Any
deferred, #type: Optional[List[Promise]]
):
# type: (...) -> Any
"""
Expand All @@ -524,7 +564,7 @@ def complete_value(
if is_thenable(result):
return Promise.resolve(result).then(
lambda resolved: complete_value(
exe_context, return_type, field_asts, info, path, resolved
exe_context, return_type, field_asts, info, path, resolved, deferred
),
lambda error: Promise.rejected(
GraphQLLocatedError(field_asts, original_error=error, path=path)
Expand All @@ -537,7 +577,7 @@ def complete_value(

if isinstance(return_type, GraphQLNonNull):
return complete_nonnull_value(
exe_context, return_type, field_asts, info, path, result
exe_context, return_type, field_asts, info, path, result, deferred
)

# If result is null-like, return null.
Expand All @@ -547,7 +587,7 @@ def complete_value(
# If field type is List, complete each item in the list with the inner type
if isinstance(return_type, GraphQLList):
return complete_list_value(
exe_context, return_type, field_asts, info, path, result
exe_context, return_type, field_asts, info, path, result, deferred
)

# If field type is Scalar or Enum, serialize to a valid value, returning
Expand All @@ -557,12 +597,12 @@ def complete_value(

if isinstance(return_type, (GraphQLInterfaceType, GraphQLUnionType)):
return complete_abstract_value(
exe_context, return_type, field_asts, info, path, result
exe_context, return_type, field_asts, info, path, result, deferred
)

if isinstance(return_type, GraphQLObjectType):
return complete_object_value(
exe_context, return_type, field_asts, info, path, result
exe_context, return_type, field_asts, info, path, result, deferred
)

assert False, u'Cannot complete value of unexpected type "{}".'.format(return_type)
Expand All @@ -575,13 +615,14 @@ def complete_list_value(
info, # type: ResolveInfo
path, # type: List[Union[int, str]]
result, # type: Any
deferred, #type: Optional[List[Promise]]
):
# type: (...) -> List[Any]
"""
Complete a list value by completing each item in the list with the inner type
"""
assert isinstance(result, collections.Iterable), (
"User Error: expected iterable, but did not find one " + "for field {}.{}."
"User Error: expected iterable, but did not find one " + "for field {}.{}."
).format(info.parent_type, info.field_name)

item_type = return_type.of_type
Expand All @@ -591,7 +632,7 @@ def complete_list_value(
index = 0
for item in result:
completed_item = complete_value_catching_error(
exe_context, item_type, field_asts, info, path + [index], item
exe_context, item_type, field_asts, info, path + [index], item, deferred
)
if not contains_promise and is_thenable(completed_item):
contains_promise = True
Expand Down Expand Up @@ -631,6 +672,7 @@ def complete_abstract_value(
info, # type: ResolveInfo
path, # type: List[Union[int, str]]
result, # type: Any
deferred, #type: Optional[List[Promise]]
):
# type: (...) -> Dict[str, Any]
"""
Expand Down Expand Up @@ -669,7 +711,7 @@ def complete_abstract_value(
)

return complete_object_value(
exe_context, runtime_type, field_asts, info, path, result
exe_context, runtime_type, field_asts, info, path, result, deferred
)


Expand All @@ -693,6 +735,7 @@ def complete_object_value(
info, # type: ResolveInfo
path, # type: List[Union[int, str]]
result, # type: Any
deferred, #type: Optional[List[Promise]]
):
# type: (...) -> Dict[str, Any]
"""
Expand All @@ -708,7 +751,7 @@ def complete_object_value(

# Collect sub-fields to execute to complete this value.
subfield_asts = exe_context.get_sub_fields(return_type, field_asts)
return execute_fields(exe_context, return_type, result, subfield_asts, path, info)
return execute_fields(exe_context, return_type, result, subfield_asts, path, info, deferred)


def complete_nonnull_value(
Expand All @@ -718,13 +761,14 @@ def complete_nonnull_value(
info, # type: ResolveInfo
path, # type: List[Union[int, str]]
result, # type: Any
deferred, #type: Optional[List[Promise]]
):
# type: (...) -> Any
"""
Complete a NonNull value by completing the inner type
"""
completed = complete_value(
exe_context, return_type.of_type, field_asts, info, path, result
exe_context, return_type.of_type, field_asts, info, path, result, deferred
)
if completed is None:
raise GraphQLError(
Expand Down
1 change: 1 addition & 0 deletions graphql/type/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
GraphQLSkipDirective,
GraphQLIncludeDirective,
GraphQLDeprecatedDirective,
GraphQLDeferDirective,
# Constant Deprecation Reason
DEFAULT_DEPRECATION_REASON,
)
Expand Down
12 changes: 12 additions & 0 deletions graphql/type/directives.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,20 @@ def __init__(self, name, description=None, args=None, locations=None):
locations=[DirectiveLocation.FIELD_DEFINITION, DirectiveLocation.ENUM_VALUE],
)


"""Used to defer the result of an element."""
GraphQLDeferDirective = GraphQLDirective(
name="defer",
description='Defers this field',
args={},
locations=[
DirectiveLocation.FIELD
],
)

specified_directives = [
GraphQLIncludeDirective,
GraphQLSkipDirective,
GraphQLDeprecatedDirective,
GraphQLDeferDirective,
]
7 changes: 7 additions & 0 deletions graphql/utils/build_ast_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from ..type import (
GraphQLArgument,
GraphQLBoolean,
GraphQLDeferDirective,
GraphQLDeprecatedDirective,
GraphQLDirective,
GraphQLEnumType,
Expand Down Expand Up @@ -308,6 +309,9 @@ def make_input_object_def(definition):
find_deprecated_directive = (
directive.name for directive in directives if directive.name == "deprecated"
)
find_defer_directive = (
directive.name for directive in directives if directive.name == "defer"
)

if not next(find_skip_directive, None):
directives.append(GraphQLSkipDirective)
Expand All @@ -318,6 +322,9 @@ def make_input_object_def(definition):
if not next(find_deprecated_directive, None):
directives.append(GraphQLDeprecatedDirective)

if not next(find_defer_directive, None):
directives.append(GraphQLDeferDirective)

schema_kwargs = {"query": get_object_type(ast_map[query_type_name])}

if mutation_type_name:
Expand Down
2 changes: 1 addition & 1 deletion graphql/utils/schema_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def print_introspection_schema(schema):

def is_spec_directive(directive_name):
# type: (str) -> bool
return directive_name in ("skip", "include", "deprecated")
return directive_name in ("skip", "include", "deprecated", "defer")


def _is_defined_type(typename):
Expand Down
Loading