From 040294877232a97e618fe3dcfdd09258aafa6c2f Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Sun, 7 Apr 2024 21:44:37 +0200 Subject: [PATCH] rename AsyncPayloadRecord to IncrementalDataRecord Replicates graphql/graphql-js@b5813f06d419c24d8dd3165d7b8ed0914b78a423 --- docs/conf.py | 4 +- src/graphql/execution/execute.py | 275 +++++++++++++++++------------- tests/execution/test_customize.py | 9 +- tests/execution/test_stream.py | 7 +- 4 files changed, 166 insertions(+), 129 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index f54580f2..6f719343 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -146,7 +146,6 @@ traceback types.TracebackType TypeMap -AsyncPayloadRecord AwaitableOrValue EnterLeaveVisitor ExperimentalIncrementalExecutionResults @@ -160,6 +159,7 @@ GraphQLTypeResolver GraphQLOutputType GroupedFieldSet +IncrementalDataRecord Middleware asyncio.events.AbstractEventLoop graphql.execution.collect_fields.FieldsAndPatches @@ -168,7 +168,7 @@ graphql.execution.execute.DeferredFragmentRecord graphql.execution.execute.ExperimentalIncrementalExecutionResults graphql.execution.execute.StreamArguments -graphql.execution.execute.StreamRecord +graphql.execution.execute.StreamItemsRecord graphql.language.lexer.EscapeSequence graphql.language.visitor.EnterLeaveVisitor graphql.type.definition.TContext diff --git a/src/graphql/execution/execute.py b/src/graphql/execution/execute.py index 14a73199..dba46135 100644 --- a/src/graphql/execution/execute.py +++ b/src/graphql/execution/execute.py @@ -104,7 +104,7 @@ async def anext(iterator: AsyncIterator) -> Any: # noqa: A001 "execute_sync", "experimental_execute_incrementally", "subscribe", - "AsyncPayloadRecord", + "IncrementalDataRecord", "DeferredFragmentRecord", "StreamItemsRecord", "ExecutionResult", @@ -632,7 +632,7 @@ class ExecutionContext: type_resolver: GraphQLTypeResolver subscribe_field_resolver: GraphQLFieldResolver errors: list[GraphQLError] - subsequent_payloads: dict[AsyncPayloadRecord, None] # used as ordered set + subsequent_payloads: dict[IncrementalDataRecord, None] # used as ordered set middleware_manager: MiddlewareManager | None is_awaitable: Callable[[Any], TypeGuard[Awaitable]] = staticmethod( @@ -650,7 +650,7 @@ def __init__( field_resolver: GraphQLFieldResolver, type_resolver: GraphQLTypeResolver, subscribe_field_resolver: GraphQLFieldResolver, - subsequent_payloads: dict[AsyncPayloadRecord, None], + subsequent_payloads: dict[IncrementalDataRecord, None], errors: list[GraphQLError], middleware_manager: MiddlewareManager | None, is_awaitable: Callable[[Any], bool] | None, @@ -883,7 +883,7 @@ def execute_fields( source_value: Any, path: Path | None, fields: GroupedFieldSet, - async_payload_record: AsyncPayloadRecord | None = None, + incremental_data_record: IncrementalDataRecord | None = None, ) -> AwaitableOrValue[dict[str, Any]]: """Execute the given fields concurrently. @@ -897,7 +897,11 @@ def execute_fields( for response_name, field_group in fields.items(): field_path = Path(path, response_name, parent_type.name) result = self.execute_field( - parent_type, source_value, field_group, field_path, async_payload_record + parent_type, + source_value, + field_group, + field_path, + incremental_data_record, ) if result is not Undefined: results[response_name] = result @@ -934,7 +938,7 @@ def execute_field( source: Any, field_group: FieldGroup, path: Path, - async_payload_record: AsyncPayloadRecord | None = None, + incremental_data_record: IncrementalDataRecord | None = None, ) -> AwaitableOrValue[Any]: """Resolve the field on the given source object. @@ -970,11 +974,16 @@ def execute_field( if self.is_awaitable(result): return self.complete_awaitable_value( - return_type, field_group, info, path, result, async_payload_record + return_type, + field_group, + info, + path, + result, + incremental_data_record, ) completed = self.complete_value( - return_type, field_group, info, path, result, async_payload_record + return_type, field_group, info, path, result, incremental_data_record ) if self.is_awaitable(completed): # noinspection PyShadowingNames @@ -987,9 +996,9 @@ async def await_completed() -> Any: return_type, field_group, path, - async_payload_record, + incremental_data_record, ) - self.filter_subsequent_payloads(path, async_payload_record) + self.filter_subsequent_payloads(path, incremental_data_record) return None return await_completed() @@ -1000,9 +1009,9 @@ async def await_completed() -> Any: return_type, field_group, path, - async_payload_record, + incremental_data_record, ) - self.filter_subsequent_payloads(path, async_payload_record) + self.filter_subsequent_payloads(path, incremental_data_record) return None return completed @@ -1041,7 +1050,7 @@ def handle_field_error( return_type: GraphQLOutputType, field_group: FieldGroup, path: Path, - async_payload_record: AsyncPayloadRecord | None = None, + incremental_data_record: IncrementalDataRecord | None = None, ) -> None: """Handle error properly according to the field type.""" error = located_error(raw_error, field_group, path.as_list()) @@ -1051,7 +1060,9 @@ def handle_field_error( if is_non_null_type(return_type): raise error - errors = async_payload_record.errors if async_payload_record else self.errors + errors = ( + incremental_data_record.errors if incremental_data_record else self.errors + ) # Otherwise, error protection is applied, logging the error and resolving a # null value for this field if one is encountered. @@ -1064,7 +1075,7 @@ def complete_value( info: GraphQLResolveInfo, path: Path, result: Any, - async_payload_record: AsyncPayloadRecord | None, + incremental_data_record: IncrementalDataRecord | None, ) -> AwaitableOrValue[Any]: """Complete a value. @@ -1101,7 +1112,7 @@ def complete_value( info, path, result, - async_payload_record, + incremental_data_record, ) if completed is None: msg = ( @@ -1118,7 +1129,7 @@ def complete_value( # If field type is List, complete each item in the list with inner type if is_list_type(return_type): return self.complete_list_value( - return_type, field_group, info, path, result, async_payload_record + return_type, field_group, info, path, result, incremental_data_record ) # If field type is a leaf type, Scalar or Enum, serialize to a valid value, @@ -1130,13 +1141,13 @@ def complete_value( # Object type and complete for that type. if is_abstract_type(return_type): return self.complete_abstract_value( - return_type, field_group, info, path, result, async_payload_record + return_type, field_group, info, path, result, incremental_data_record ) # If field type is Object, execute and complete all sub-selections. if is_object_type(return_type): return self.complete_object_value( - return_type, field_group, info, path, result, async_payload_record + return_type, field_group, info, path, result, incremental_data_record ) # Not reachable. All possible output types have been considered. @@ -1153,7 +1164,7 @@ async def complete_awaitable_value( info: GraphQLResolveInfo, path: Path, result: Any, - async_payload_record: AsyncPayloadRecord | None = None, + incremental_data_record: IncrementalDataRecord | None = None, ) -> Any: """Complete an awaitable value.""" try: @@ -1164,15 +1175,15 @@ async def complete_awaitable_value( info, path, resolved, - async_payload_record, + incremental_data_record, ) if self.is_awaitable(completed): completed = await completed except Exception as raw_error: self.handle_field_error( - raw_error, return_type, field_group, path, async_payload_record + raw_error, return_type, field_group, path, incremental_data_record ) - self.filter_subsequent_payloads(path, async_payload_record) + self.filter_subsequent_payloads(path, incremental_data_record) completed = None return completed @@ -1220,7 +1231,7 @@ async def complete_async_iterator_value( info: GraphQLResolveInfo, path: Path, async_iterator: AsyncIterator[Any], - async_payload_record: AsyncPayloadRecord | None, + incremental_data_record: IncrementalDataRecord | None, ) -> list[Any]: """Complete an async iterator. @@ -1250,7 +1261,7 @@ async def complete_async_iterator_value( item_type, path, stream.label, - async_payload_record, + incremental_data_record, ) ), timeout=ASYNC_DELAY, @@ -1265,7 +1276,11 @@ async def complete_async_iterator_value( break except Exception as raw_error: self.handle_field_error( - raw_error, item_type, field_group, item_path, async_payload_record + raw_error, + item_type, + field_group, + item_path, + incremental_data_record, ) completed_results.append(None) break @@ -1276,7 +1291,7 @@ async def complete_async_iterator_value( field_group, info, item_path, - async_payload_record, + incremental_data_record, ): append_awaitable(index) @@ -1306,7 +1321,7 @@ def complete_list_value( info: GraphQLResolveInfo, path: Path, result: AsyncIterable[Any] | Iterable[Any], - async_payload_record: AsyncPayloadRecord | None, + incremental_data_record: IncrementalDataRecord | None, ) -> AwaitableOrValue[list[Any]]: """Complete a list value. @@ -1318,7 +1333,12 @@ def complete_list_value( async_iterator = result.__aiter__() return self.complete_async_iterator_value( - item_type, field_group, info, path, async_iterator, async_payload_record + item_type, + field_group, + info, + path, + async_iterator, + incremental_data_record, ) if not is_iterable(result): @@ -1336,7 +1356,7 @@ def complete_list_value( complete_list_item_value = self.complete_list_item_value awaitable_indices: list[int] = [] append_awaitable = awaitable_indices.append - previous_async_payload_record = async_payload_record + previous_incremental_data_record = incremental_data_record completed_results: list[Any] = [] for index, item in enumerate(result): # No need to modify the info object containing the path, since from here on @@ -1348,7 +1368,7 @@ def complete_list_value( and isinstance(stream.initial_count, int) and index >= stream.initial_count ): - previous_async_payload_record = self.execute_stream_field( + previous_incremental_data_record = self.execute_stream_field( path, item_path, item, @@ -1356,7 +1376,7 @@ def complete_list_value( info, item_type, stream.label, - previous_async_payload_record, + previous_incremental_data_record, ) continue @@ -1367,7 +1387,7 @@ def complete_list_value( field_group, info, item_path, - async_payload_record, + incremental_data_record, ): append_awaitable(index) @@ -1400,7 +1420,7 @@ def complete_list_item_value( field_group: FieldGroup, info: GraphQLResolveInfo, item_path: Path, - async_payload_record: AsyncPayloadRecord | None, + incremental_data_record: IncrementalDataRecord | None, ) -> bool: """Complete a list item value by adding it to the completed results. @@ -1411,7 +1431,12 @@ def complete_list_item_value( if is_awaitable(item): complete_results.append( self.complete_awaitable_value( - item_type, field_group, info, item_path, item, async_payload_record + item_type, + field_group, + info, + item_path, + item, + incremental_data_record, ) ) return True @@ -1423,7 +1448,7 @@ def complete_list_item_value( info, item_path, item, - async_payload_record, + incremental_data_record, ) if is_awaitable(completed_item): @@ -1437,9 +1462,11 @@ async def await_completed() -> Any: item_type, field_group, item_path, - async_payload_record, + incremental_data_record, + ) + self.filter_subsequent_payloads( + item_path, incremental_data_record ) - self.filter_subsequent_payloads(item_path, async_payload_record) return None complete_results.append(await_completed()) @@ -1453,9 +1480,9 @@ async def await_completed() -> Any: item_type, field_group, item_path, - async_payload_record, + incremental_data_record, ) - self.filter_subsequent_payloads(item_path, async_payload_record) + self.filter_subsequent_payloads(item_path, incremental_data_record) complete_results.append(None) return False @@ -1484,7 +1511,7 @@ def complete_abstract_value( info: GraphQLResolveInfo, path: Path, result: Any, - async_payload_record: AsyncPayloadRecord | None, + incremental_data_record: IncrementalDataRecord | None, ) -> AwaitableOrValue[Any]: """Complete an abstract value. @@ -1510,7 +1537,7 @@ async def await_complete_object_value() -> Any: info, path, result, - async_payload_record, + incremental_data_record, ) if self.is_awaitable(value): return await value # type: ignore @@ -1527,7 +1554,7 @@ async def await_complete_object_value() -> Any: info, path, result, - async_payload_record, + incremental_data_record, ) def ensure_valid_runtime_type( @@ -1599,7 +1626,7 @@ def complete_object_value( info: GraphQLResolveInfo, path: Path, result: Any, - async_payload_record: AsyncPayloadRecord | None, + incremental_data_record: IncrementalDataRecord | None, ) -> AwaitableOrValue[dict[str, Any]]: """Complete an Object value by executing all sub-selections.""" # If there is an `is_type_of()` predicate function, call it with the current @@ -1616,7 +1643,7 @@ async def execute_subfields_async() -> dict[str, Any]: return_type, result, field_group ) return self.collect_and_execute_subfields( - return_type, field_group, path, result, async_payload_record + return_type, field_group, path, result, incremental_data_record ) # type: ignore return execute_subfields_async() @@ -1625,7 +1652,7 @@ async def execute_subfields_async() -> dict[str, Any]: raise invalid_return_type_error(return_type, result, field_group) return self.collect_and_execute_subfields( - return_type, field_group, path, result, async_payload_record + return_type, field_group, path, result, incremental_data_record ) def collect_and_execute_subfields( @@ -1634,7 +1661,7 @@ def collect_and_execute_subfields( field_group: FieldGroup, path: Path, result: Any, - async_payload_record: AsyncPayloadRecord | None, + incremental_data_record: IncrementalDataRecord | None, ) -> AwaitableOrValue[dict[str, Any]]: """Collect sub-fields to execute to complete this value.""" sub_grouped_field_set, sub_patches = self.collect_subfields( @@ -1642,7 +1669,7 @@ def collect_and_execute_subfields( ) sub_fields = self.execute_fields( - return_type, result, path, sub_grouped_field_set, async_payload_record + return_type, result, path, sub_grouped_field_set, incremental_data_record ) for sub_patch in sub_patches: @@ -1653,7 +1680,7 @@ def collect_and_execute_subfields( sub_patch_field_nodes, label, path, - async_payload_record, + incremental_data_record, ) return sub_fields @@ -1731,13 +1758,15 @@ def execute_deferred_fragment( fields: GroupedFieldSet, label: str | None = None, path: Path | None = None, - parent_context: AsyncPayloadRecord | None = None, + parent_context: IncrementalDataRecord | None = None, ) -> None: """Execute deferred fragment.""" - async_payload_record = DeferredFragmentRecord(label, path, parent_context, self) + incremental_data_record = DeferredFragmentRecord( + label, path, parent_context, self + ) try: awaitable_or_data = self.execute_fields( - parent_type, source_value, path, fields, async_payload_record + parent_type, source_value, path, fields, incremental_data_record ) if self.is_awaitable(awaitable_or_data): @@ -1749,15 +1778,15 @@ async def await_data( try: return await awaitable except GraphQLError as error: - async_payload_record.errors.append(error) + incremental_data_record.errors.append(error) return None awaitable_or_data = await_data(awaitable_or_data) # type: ignore except GraphQLError as error: - async_payload_record.errors.append(error) + incremental_data_record.errors.append(error) awaitable_or_data = None - async_payload_record.add_data(awaitable_or_data) + incremental_data_record.add_data(awaitable_or_data) def execute_stream_field( self, @@ -1768,11 +1797,11 @@ def execute_stream_field( info: GraphQLResolveInfo, item_type: GraphQLOutputType, label: str | None = None, - parent_context: AsyncPayloadRecord | None = None, - ) -> AsyncPayloadRecord: + parent_context: IncrementalDataRecord | None = None, + ) -> IncrementalDataRecord: """Execute stream field.""" is_awaitable = self.is_awaitable - async_payload_record = StreamItemsRecord( + incremental_data_record = StreamItemsRecord( label, item_path, None, parent_context, self ) completed_item: Any @@ -1788,16 +1817,16 @@ async def await_completed_items() -> list[Any] | None: info, item_path, item, - async_payload_record, + incremental_data_record, ) ] except GraphQLError as error: - async_payload_record.errors.append(error) - self.filter_subsequent_payloads(path, async_payload_record) + incremental_data_record.errors.append(error) + self.filter_subsequent_payloads(path, incremental_data_record) return None - async_payload_record.add_items(await_completed_items()) - return async_payload_record + incremental_data_record.add_items(await_completed_items()) + return incremental_data_record try: try: @@ -1807,7 +1836,7 @@ async def await_completed_items() -> list[Any] | None: info, item_path, item, - async_payload_record, + incremental_data_record, ) completed_items: Any @@ -1825,15 +1854,17 @@ async def await_completed_items() -> list[Any] | None: item_type, field_group, item_path, - async_payload_record, + incremental_data_record, ) self.filter_subsequent_payloads( - item_path, async_payload_record + item_path, incremental_data_record ) return [None] except GraphQLError as error: # pragma: no cover - async_payload_record.errors.append(error) - self.filter_subsequent_payloads(path, async_payload_record) + incremental_data_record.errors.append(error) + self.filter_subsequent_payloads( + path, incremental_data_record + ) return None completed_items = await_completed_items() @@ -1846,18 +1877,18 @@ async def await_completed_items() -> list[Any] | None: item_type, field_group, item_path, - async_payload_record, + incremental_data_record, ) - self.filter_subsequent_payloads(item_path, async_payload_record) + self.filter_subsequent_payloads(item_path, incremental_data_record) completed_items = [None] except GraphQLError as error: - async_payload_record.errors.append(error) - self.filter_subsequent_payloads(item_path, async_payload_record) + incremental_data_record.errors.append(error) + self.filter_subsequent_payloads(item_path, incremental_data_record) completed_items = None - async_payload_record.add_items(completed_items) - return async_payload_record + incremental_data_record.add_items(completed_items) + return incremental_data_record async def execute_stream_async_iterator_item( self, @@ -1865,7 +1896,7 @@ async def execute_stream_async_iterator_item( field_group: FieldGroup, info: GraphQLResolveInfo, item_type: GraphQLOutputType, - async_payload_record: StreamItemsRecord, + incremental_data_record: StreamItemsRecord, item_path: Path, ) -> Any: """Execute stream iterator item.""" @@ -1874,7 +1905,7 @@ async def execute_stream_async_iterator_item( try: item = await anext(async_iterator) completed_item = self.complete_value( - item_type, field_group, info, item_path, item, async_payload_record + item_type, field_group, info, item_path, item, incremental_data_record ) return ( @@ -1884,14 +1915,14 @@ async def execute_stream_async_iterator_item( ) except StopAsyncIteration as raw_error: - async_payload_record.set_is_completed_async_iterator() + incremental_data_record.set_is_completed_async_iterator() raise StopAsyncIteration from raw_error except Exception as raw_error: self.handle_field_error( - raw_error, item_type, field_group, item_path, async_payload_record + raw_error, item_type, field_group, item_path, incremental_data_record ) - self.filter_subsequent_payloads(item_path, async_payload_record) + self.filter_subsequent_payloads(item_path, incremental_data_record) async def execute_stream_async_iterator( self, @@ -1902,16 +1933,16 @@ async def execute_stream_async_iterator( item_type: GraphQLOutputType, path: Path, label: str | None = None, - parent_context: AsyncPayloadRecord | None = None, + parent_context: IncrementalDataRecord | None = None, ) -> None: """Execute stream iterator.""" index = initial_index - previous_async_payload_record = parent_context + previous_incremental_data_record = parent_context while True: item_path = Path(path, index, None) - async_payload_record = StreamItemsRecord( - label, item_path, async_iterator, previous_async_payload_record, self + incremental_data_record = StreamItemsRecord( + label, item_path, async_iterator, previous_incremental_data_record, self ) try: @@ -1920,19 +1951,19 @@ async def execute_stream_async_iterator( field_group, info, item_type, - async_payload_record, + incremental_data_record, item_path, ) except StopAsyncIteration: - if async_payload_record.errors: - async_payload_record.add_items(None) # pragma: no cover + if incremental_data_record.errors: + incremental_data_record.add_items(None) # pragma: no cover else: - del self.subsequent_payloads[async_payload_record] + del self.subsequent_payloads[incremental_data_record] break except GraphQLError as error: - async_payload_record.errors.append(error) - self.filter_subsequent_payloads(path, async_payload_record) - async_payload_record.add_items(None) + incremental_data_record.errors.append(error) + self.filter_subsequent_payloads(path, incremental_data_record) + incremental_data_record.add_items(None) if async_iterator: # pragma: no cover else with suppress(Exception): await async_iterator.aclose() # type: ignore @@ -1941,65 +1972,65 @@ async def execute_stream_async_iterator( self._canceled_iterators.add(async_iterator) break - async_payload_record.add_items([data]) + incremental_data_record.add_items([data]) - previous_async_payload_record = async_payload_record + previous_incremental_data_record = incremental_data_record index += 1 def filter_subsequent_payloads( self, null_path: Path, - current_async_record: AsyncPayloadRecord | None = None, + current_incremental_data_record: IncrementalDataRecord | None = None, ) -> None: """Filter subsequent payloads.""" null_path_list = null_path.as_list() - for async_record in list(self.subsequent_payloads): - if async_record is current_async_record: + for incremental_data_record in list(self.subsequent_payloads): + if incremental_data_record is current_incremental_data_record: # don't remove payload from where error originates continue - if async_record.path[: len(null_path_list)] != null_path_list: - # async_record points to a path unaffected by this payload + if incremental_data_record.path[: len(null_path_list)] != null_path_list: + # incremental_data_record points to a path unaffected by this payload continue - # async_record path points to nulled error field + # incremental_data_record path points to nulled error field if ( - isinstance(async_record, StreamItemsRecord) - and async_record.async_iterator + isinstance(incremental_data_record, StreamItemsRecord) + and incremental_data_record.async_iterator ): - self._canceled_iterators.add(async_record.async_iterator) - del self.subsequent_payloads[async_record] + self._canceled_iterators.add(incremental_data_record.async_iterator) + del self.subsequent_payloads[incremental_data_record] def get_completed_incremental_results(self) -> list[IncrementalResult]: """Get completed incremental results.""" incremental_results: list[IncrementalResult] = [] append_result = incremental_results.append subsequent_payloads = list(self.subsequent_payloads) - for async_payload_record in subsequent_payloads: + for incremental_data_record in subsequent_payloads: incremental_result: IncrementalResult - if not async_payload_record.completed.is_set(): + if not incremental_data_record.completed.is_set(): continue - del self.subsequent_payloads[async_payload_record] - if isinstance(async_payload_record, StreamItemsRecord): - items = async_payload_record.items - if async_payload_record.is_completed_async_iterator: + del self.subsequent_payloads[incremental_data_record] + if isinstance(incremental_data_record, StreamItemsRecord): + items = incremental_data_record.items + if incremental_data_record.is_completed_async_iterator: # async iterable resolver finished but there may be pending payload continue # pragma: no cover incremental_result = IncrementalStreamResult( items, - async_payload_record.errors - if async_payload_record.errors + incremental_data_record.errors + if incremental_data_record.errors else None, - async_payload_record.path, - async_payload_record.label, + incremental_data_record.path, + incremental_data_record.label, ) else: - data = async_payload_record.data + data = incremental_data_record.data incremental_result = IncrementalDeferResult( data, - async_payload_record.errors - if async_payload_record.errors + incremental_data_record.errors + if incremental_data_record.errors else None, - async_payload_record.path, - async_payload_record.label, + incremental_data_record.path, + incremental_data_record.label, ) append_result(incremental_result) @@ -2604,7 +2635,7 @@ class DeferredFragmentRecord: label: str | None path: list[str | int] data: dict[str, Any] | None - parent_context: AsyncPayloadRecord | None + parent_context: IncrementalDataRecord | None completed: Event _context: ExecutionContext _data: AwaitableOrValue[dict[str, Any] | None] @@ -2614,7 +2645,7 @@ def __init__( self, label: str | None, path: Path | None, - parent_context: AsyncPayloadRecord | None, + parent_context: IncrementalDataRecord | None, context: ExecutionContext, ) -> None: self.label = label @@ -2669,7 +2700,7 @@ class StreamItemsRecord: label: str | None path: list[str | int] items: list[str] | None - parent_context: AsyncPayloadRecord | None + parent_context: IncrementalDataRecord | None async_iterator: AsyncIterator[Any] | None is_completed_async_iterator: bool completed: Event @@ -2682,7 +2713,7 @@ def __init__( label: str | None, path: Path | None, async_iterator: AsyncIterator[Any] | None, - parent_context: AsyncPayloadRecord | None, + parent_context: IncrementalDataRecord | None, context: ExecutionContext, ) -> None: self.label = label @@ -2738,4 +2769,4 @@ def set_is_completed_async_iterator(self) -> None: self._items_added.set() -AsyncPayloadRecord = Union[DeferredFragmentRecord, StreamItemsRecord] +IncrementalDataRecord = Union[DeferredFragmentRecord, StreamItemsRecord] diff --git a/tests/execution/test_customize.py b/tests/execution/test_customize.py index 6d8cd369..23740237 100644 --- a/tests/execution/test_customize.py +++ b/tests/execution/test_customize.py @@ -43,10 +43,15 @@ def uses_a_custom_execution_context_class(): class TestExecutionContext(ExecutionContext): def execute_field( - self, parent_type, source, field_group, path, async_payload_record=None + self, + parent_type, + source, + field_group, + path, + incremental_data_record=None, ): result = super().execute_field( - parent_type, source, field_group, path, async_payload_record + parent_type, source, field_group, path, incremental_data_record ) return result * 2 # type: ignore diff --git a/tests/execution/test_stream.py b/tests/execution/test_stream.py index 091484e2..b8c722a2 100644 --- a/tests/execution/test_stream.py +++ b/tests/execution/test_stream.py @@ -176,16 +176,17 @@ def can_print_stream_record(): context = ExecutionContext.build(schema, parse("{ hero { id } }")) assert isinstance(context, ExecutionContext) record = StreamItemsRecord(None, None, None, None, context) - assert str(record) == "StreamRecord(path=[])" + assert str(record) == "StreamItemsRecord(path=[])" record = StreamItemsRecord( "foo", Path(None, "bar", "Bar"), None, record, context ) assert ( - str(record) == "StreamRecord(" "path=['bar'], label='foo', parent_context)" + str(record) == "StreamItemsRecord(" + "path=['bar'], label='foo', parent_context)" ) record.items = ["hello", "world"] assert ( - str(record) == "StreamRecord(" + str(record) == "StreamItemsRecord(" "path=['bar'], label='foo', parent_context, items)" )