Skip to content

Commit

Permalink
Make strawberry.Private compatible with PEP-563
Browse files Browse the repository at this point in the history
This commit Closes strawberry-graphql#1586 by checking for private fields at schema
conversion time. As per [PEP-563], in the future Python annotations will
no longer be evaluated at definition time. As reported by Issue strawberry-graphql#1586,
this means that `strawberry.Private` is incompatible with postponed
evaluation as the check for a private field requires an evaluated
type-hint annotation to work.

By checking for private fields at schema conversion time, it is
guaranteed that all fields that should be included in the schema are
resolvable using an eval. This ensures that the current approach for
defining private fields can be left intact.

The current filtering for fields annotated with `strawberry.Private` in
`types.type_resolver.py` are left intact to not needlessly instantiate
`StrawberryField` objects when `strawberry.Private` is resolvable.

Summary of Changes:
- Added check for private fields at schema evaluation time
- Added test to check that postponed evaluation with
  `strawberry.Private` functions correctly
- Reduced code duplication in `schema_converter.py`

[PEP-563]: https://www.python.org/dev/peps/pep-0563/
  • Loading branch information
skilkis committed Mar 7, 2022
1 parent c6093b0 commit ee20f83
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 32 deletions.
80 changes: 48 additions & 32 deletions strawberry/schema/schema_converter.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,18 @@
from __future__ import annotations

from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union, cast
from typing import (
Any,
Callable,
Dict,
List,
Optional,
Tuple,
Type,
TypeVar,
Union,
cast,
)

from graphql import (
GraphQLArgument,
Expand Down Expand Up @@ -34,6 +45,7 @@
)
from strawberry.field import StrawberryField
from strawberry.lazy_type import LazyType
from strawberry.private import is_private
from strawberry.schema.config import StrawberryConfig
from strawberry.schema.types.scalar import _make_scalar_type
from strawberry.type import StrawberryList, StrawberryOptional, StrawberryType
Expand Down Expand Up @@ -163,6 +175,38 @@ def from_input_field(self, field: StrawberryField) -> GraphQLInputField:
deprecation_reason=field.deprecation_reason,
)

FieldType = TypeVar("FieldType", GraphQLField, GraphQLInputField)

@staticmethod
def _get_thunk_mapping(
fields: List[StrawberryField],
name_converter: Callable[[StrawberryField], str],
field_converter: Callable[[StrawberryField], FieldType],
) -> Dict[str, FieldType]:
return {
name_converter(f): field_converter(f)
for f in fields
if not is_private(f.type)
}

def get_graphql_fields(
self, type_definition: TypeDefinition
) -> Dict[str, GraphQLField]:
return self._get_thunk_mapping(
fields=type_definition.fields,
name_converter=self.config.name_converter.from_field,
field_converter=self.from_field,
)

def get_graphql_input_fields(
self, type_definition: TypeDefinition
) -> Dict[str, GraphQLInputField]:
return self._get_thunk_mapping(
fields=type_definition.fields,
name_converter=self.config.name_converter.from_field,
field_converter=self.from_input_field,
)

def from_input_object(self, object_type: type) -> GraphQLInputObjectType:
type_definition = object_type._type_definition # type: ignore

Expand All @@ -174,18 +218,9 @@ def from_input_object(self, object_type: type) -> GraphQLInputObjectType:
assert isinstance(graphql_object_type, GraphQLInputObjectType) # For mypy
return graphql_object_type

def get_graphql_fields() -> Dict[str, GraphQLInputField]:
graphql_fields = {}
for field in type_definition.fields:
field_name = self.config.name_converter.from_field(field)

graphql_fields[field_name] = self.from_input_field(field)

return graphql_fields

graphql_object_type = GraphQLInputObjectType(
name=type_name,
fields=get_graphql_fields,
fields=lambda: self.get_graphql_input_fields(type_definition),
description=type_definition.description,
)

Expand All @@ -206,18 +241,9 @@ def from_interface(self, interface: TypeDefinition) -> GraphQLInterfaceType:
assert isinstance(graphql_interface, GraphQLInterfaceType) # For mypy
return graphql_interface

def get_graphql_fields() -> Dict[str, GraphQLField]:
graphql_fields = {}

for field in interface.fields:
field_name = self.config.name_converter.from_field(field)
graphql_fields[field_name] = self.from_field(field)

return graphql_fields

graphql_interface = GraphQLInterfaceType(
name=interface_name,
fields=get_graphql_fields,
fields=lambda: self.get_graphql_fields(interface),
interfaces=list(map(self.from_interface, interface.interfaces)),
description=interface.description,
)
Expand All @@ -243,16 +269,6 @@ def from_object(self, object_type: TypeDefinition) -> GraphQLObjectType:
assert isinstance(graphql_object_type, GraphQLObjectType) # For mypy
return graphql_object_type

def get_graphql_fields() -> Dict[str, GraphQLField]:
graphql_fields = {}

for field in object_type.fields:
field_name = self.config.name_converter.from_field(field)

graphql_fields[field_name] = self.from_field(field)

return graphql_fields

is_type_of: Optional[Callable[[Any, GraphQLResolveInfo], bool]]
if object_type.is_type_of:
is_type_of = object_type.is_type_of
Expand All @@ -266,7 +282,7 @@ def is_type_of(obj: Any, _info: GraphQLResolveInfo) -> bool:

graphql_object_type = GraphQLObjectType(
name=object_type_name,
fields=get_graphql_fields,
fields=lambda: self.get_graphql_fields(object_type),
interfaces=list(map(self.from_interface, object_type.interfaces)),
description=object_type.description,
is_type_of=is_type_of,
Expand Down
27 changes: 27 additions & 0 deletions tests/schema/test_private_field.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from logging import root
import pytest

import strawberry
Expand Down Expand Up @@ -55,3 +56,29 @@ def age_in_months(self) -> int:
assert result.data == {
"ageInMonths": 84,
}


def test_private_field_with_str_annotations():
"""Check compatibility of strawberry.Private with annotations as string."""

from dataclasses import dataclass

@strawberry.type
class Query:
not_seen: "strawberry.Private[SensitiveData]"

@strawberry.field
def accesible_info(self) -> str:
return self.not_seen.info

@dataclass
class SensitiveData:
value: int
info: str

schema = strawberry.Schema(query=Query)

result = schema.execute_sync(
"query { accesibleInfo }", root_value=Query(not_seen=SensitiveData(1, "foo"))
)
assert result.data == {"accesibleInfo": "foo"}

0 comments on commit ee20f83

Please sign in to comment.