Skip to content

Commit

Permalink
improve type hint resolution #199
Browse files Browse the repository at this point in the history
  • Loading branch information
tfranzel committed Nov 14, 2020
1 parent 05e4b90 commit bc80603
Show file tree
Hide file tree
Showing 4 changed files with 152 additions and 26 deletions.
27 changes: 10 additions & 17 deletions drf_spectacular/openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@
from drf_spectacular.drainage import get_override, has_override
from drf_spectacular.extensions import OpenApiSerializerExtension, OpenApiSerializerFieldExtension
from drf_spectacular.plumbing import (
ComponentRegistry, ResolvedComponent, anyisinstance, append_meta, build_array_type,
build_basic_type, build_choice_field, build_object_type, build_parameter_type, error,
follow_field_source, force_instance, get_doc, get_view_model, is_basic_type, is_field,
is_serializer, resolve_regex_path_parameter, safe_ref, warn,
ComponentRegistry, ResolvedComponent, UnableToProceedError, anyisinstance, append_meta,
build_array_type, build_basic_type, build_choice_field, build_object_type, build_parameter_type,
error, follow_field_source, force_instance, get_doc, get_view_model, is_basic_type, is_field,
is_serializer, resolve_regex_path_parameter, resolve_type_hint, safe_ref, warn,
)
from drf_spectacular.settings import spectacular_settings
from drf_spectacular.types import OpenApiTypes
Expand Down Expand Up @@ -726,20 +726,13 @@ def _map_response_type_hint(self, method):

if is_serializer(hint) or is_field(hint):
return self._map_serializer_field(force_instance(hint), 'response')
elif is_basic_type(hint, allow_none=False):
return build_basic_type(hint)
elif getattr(hint, '__origin__', None) is typing.Union:
if type(None) == hint.__args__[1] and len(hint.__args__) == 2:
schema = build_basic_type(hint.__args__[0])
schema['nullable'] = True
return schema
else:
warn(f'type hint {hint} not supported yet. defaulting to "string"')
return build_basic_type(OpenApiTypes.STR)
else:

try:
return resolve_type_hint(hint)
except UnableToProceedError:
warn(
f'type hint for function "{method.__name__}" is unknown. consider using '
f'a type hint or @extend_schema_field. defaulting to string.'
f'unable to resolve type hint for function "{method.__name__}". consider '
f'using a type hint or @extend_schema_field. defaulting to string.'
)
return build_basic_type(OpenApiTypes.STR)

Expand Down
71 changes: 66 additions & 5 deletions drf_spectacular/plumbing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import inspect
import json
import re
import sys
import typing
import urllib.parse
from abc import ABCMeta
from collections import OrderedDict, defaultdict
Expand Down Expand Up @@ -158,11 +160,13 @@ def build_basic_type(obj):
return dict(OPENAPI_TYPE_MAPPING[OpenApiTypes.STR])


def build_array_type(schema):
return {
'type': 'array',
'items': schema,
}
def build_array_type(schema, min_length=None, max_length=None):
schema = {'type': 'array', 'items': schema}
if min_length is not None:
schema['minLength'] = min_length
if max_length is not None:
schema['maxLength'] = max_length
return schema


def build_object_type(
Expand Down Expand Up @@ -731,3 +735,60 @@ def set_query_parameters(url, **kwargs) -> str:
def get_relative_url(url: str) -> str:
scheme, netloc, path, params, query, fragment = urllib.parse.urlparse(url)
return urllib.parse.urlunparse(('', '', path, params, query, fragment))


def _get_type_hint_origin(hint):
""" graceful fallback for py 3.8 typing functionality """
if sys.version_info >= (3, 8):
return typing.get_origin(hint), typing.get_args(hint)
else:
origin = getattr(hint, '__origin__', None)
args = getattr(hint, '__args__', None)
origin = {
typing.List: list,
typing.Dict: dict,
typing.Tuple: tuple,
typing.Set: set,
typing.FrozenSet: frozenset
}.get(origin, origin)
return origin, args


def resolve_type_hint(hint):
""" resolve return value type hints to schema """
origin, args = _get_type_hint_origin(hint)

if origin is None and is_basic_type(hint, allow_none=False):
return build_basic_type(hint)
elif origin is list or hint is list:
return build_array_type(build_basic_type(args[0] if args else OpenApiTypes.OBJECT))
elif origin is tuple:
return build_array_type(
schema=build_basic_type(args[0]),
max_length=len(args),
min_length=len(args),
)
elif origin is dict or origin is defaultdict or origin is OrderedDict:
schema = build_basic_type(OpenApiTypes.OBJECT)
if args[1] is not typing.Any:
schema['additionalProperties'] = resolve_type_hint(args[1])
return schema
elif origin is set:
return build_array_type(resolve_type_hint(args[0]))
elif origin is frozenset:
return build_array_type(resolve_type_hint(args[0]))
elif hasattr(typing, 'Literal') and origin is typing.Literal:
# python >= 3.8
schema = {'enum': list(args)}
if all(type(args[0]) is type(choice) for choice in args):
schema.update(build_basic_type(type(args[0])))
return schema
elif origin is typing.Union and len(args) == 2 and isinstance(None, args[1]):
# Optional[*] is resolved to Union[*, None]
schema = resolve_type_hint(args[0])
schema['nullable'] = True
return schema
elif origin is typing.Union:
return {'oneOf': [resolve_type_hint(arg) for arg in args]}
else:
raise UnableToProceedError()
76 changes: 74 additions & 2 deletions tests/test_plumbing.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,20 @@
import json
import sys
import typing
from datetime import datetime

import pytest
from django.conf.urls import include
from django.db import models
from django.urls import re_path
from rest_framework import serializers
from rest_framework import generics, serializers

from drf_spectacular.openapi import AutoSchema
from drf_spectacular.plumbing import (
detype_pattern, follow_field_source, force_instance, is_field, is_serializer,
detype_pattern, follow_field_source, force_instance, is_field, is_serializer, resolve_type_hint,
)
from drf_spectacular.validation import validate_schema
from tests import generate_schema


def test_is_serializer():
Expand Down Expand Up @@ -69,3 +77,67 @@ def test_detype_patterns_with_module_includes(no_warnings):
detype_pattern(
pattern=re_path(r'^', include('tests.test_fields'))
)


TYPE_HINT_TEST_PARAMS = [
(
typing.Optional[int],
{'type': 'integer', 'nullable': True}
), (
typing.List[int],
{'type': 'array', 'items': {'type': 'integer'}}
), (
list,
{'type': 'array', 'items': {'type': 'object', 'additionalProperties': {}}}
), (
typing.Tuple[int, int, int],
{'type': 'array', 'items': {'type': 'integer'}, 'minLength': 3, 'maxLength': 3}
), (
typing.Set[datetime],
{'type': 'array', 'items': {'type': 'string', 'format': 'date-time'}}
), (
typing.FrozenSet[datetime],
{'type': 'array', 'items': {'type': 'string', 'format': 'date-time'}}
), (
typing.Dict[str, int],
{'type': 'object', 'additionalProperties': {'type': 'integer'}}
), (
typing.Dict[str, str],
{'type': 'object', 'additionalProperties': {'type': 'string'}}
), (
typing.Dict[str, typing.List[int]],
{'type': 'object', 'additionalProperties': {'type': 'array', 'items': {'type': 'integer'}}}
), (
dict,
{'type': 'object', 'additionalProperties': {}}
), (
typing.Union[int, str],
{'oneOf': [{'type': 'integer'}, {'type': 'string'}]}
)
]

if sys.version_info >= (3, 8):
TYPE_HINT_TEST_PARAMS.append((
typing.Literal['x', 'y'],
{'enum': ['x', 'y'], 'type': 'string'}
))


@pytest.mark.parametrize(['type_hint', 'ref_schema'], TYPE_HINT_TEST_PARAMS)
def test_type_hint_extraction(no_warnings, type_hint, ref_schema):
def func() -> type_hint:
pass # pragma: no cover

# check expected resolution
schema = resolve_type_hint(typing.get_type_hints(func).get('return'))
assert json.dumps(schema) == json.dumps(ref_schema)

# check schema validity
class XSerializer(serializers.Serializer):
x = serializers.SerializerMethodField()
XSerializer.get_x = func

class XView(generics.RetrieveAPIView):
serializer_class = XSerializer

validate_schema(generate_schema('/x', view=XView))
4 changes: 2 additions & 2 deletions tests/test_warnings.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,8 +203,8 @@ def get(self, request):

generate_schema('x', view=XAPIView)
stderr = capsys.readouterr().err
assert 'type hint for function "x" is unknown.' in stderr
assert 'type hint for function "get_y" is unknown.' in stderr
assert 'unable to resolve type hint for function "x"' in stderr
assert 'unable to resolve type hint for function "get_y"' in stderr


def test_operation_id_collision_resolution(capsys):
Expand Down

0 comments on commit bc80603

Please sign in to comment.