From bc806032989078b76a1764a962163626e7494d96 Mon Sep 17 00:00:00 2001 From: "T. Franzel" Date: Sat, 14 Nov 2020 19:18:11 +0100 Subject: [PATCH] improve type hint resolution #199 --- drf_spectacular/openapi.py | 27 +++++-------- drf_spectacular/plumbing.py | 71 +++++++++++++++++++++++++++++++--- tests/test_plumbing.py | 76 ++++++++++++++++++++++++++++++++++++- tests/test_warnings.py | 4 +- 4 files changed, 152 insertions(+), 26 deletions(-) diff --git a/drf_spectacular/openapi.py b/drf_spectacular/openapi.py index 7d1ef4ca..56f24be2 100644 --- a/drf_spectacular/openapi.py +++ b/drf_spectacular/openapi.py @@ -22,10 +22,10 @@ from drf_spectacular.drainage import get_override, has_override from drf_spectacular.extensions import OpenApiSerializerExtension, OpenApiSerializerFieldExtension from drf_spectacular.plumbing import ( - ComponentRegistry, ResolvedComponent, anyisinstance, append_meta, build_array_type, - build_basic_type, build_choice_field, build_object_type, build_parameter_type, error, - follow_field_source, force_instance, get_doc, get_view_model, is_basic_type, is_field, - is_serializer, resolve_regex_path_parameter, safe_ref, warn, + ComponentRegistry, ResolvedComponent, UnableToProceedError, anyisinstance, append_meta, + build_array_type, build_basic_type, build_choice_field, build_object_type, build_parameter_type, + error, follow_field_source, force_instance, get_doc, get_view_model, is_basic_type, is_field, + is_serializer, resolve_regex_path_parameter, resolve_type_hint, safe_ref, warn, ) from drf_spectacular.settings import spectacular_settings from drf_spectacular.types import OpenApiTypes @@ -726,20 +726,13 @@ def _map_response_type_hint(self, method): if is_serializer(hint) or is_field(hint): return self._map_serializer_field(force_instance(hint), 'response') - elif is_basic_type(hint, allow_none=False): - return build_basic_type(hint) - elif getattr(hint, '__origin__', None) is typing.Union: - if type(None) == hint.__args__[1] and len(hint.__args__) == 2: - schema = build_basic_type(hint.__args__[0]) - schema['nullable'] = True - return schema - else: - warn(f'type hint {hint} not supported yet. defaulting to "string"') - return build_basic_type(OpenApiTypes.STR) - else: + + try: + return resolve_type_hint(hint) + except UnableToProceedError: warn( - f'type hint for function "{method.__name__}" is unknown. consider using ' - f'a type hint or @extend_schema_field. defaulting to string.' + f'unable to resolve type hint for function "{method.__name__}". consider ' + f'using a type hint or @extend_schema_field. defaulting to string.' ) return build_basic_type(OpenApiTypes.STR) diff --git a/drf_spectacular/plumbing.py b/drf_spectacular/plumbing.py index 5bacf350..234af8df 100644 --- a/drf_spectacular/plumbing.py +++ b/drf_spectacular/plumbing.py @@ -2,6 +2,8 @@ import inspect import json import re +import sys +import typing import urllib.parse from abc import ABCMeta from collections import OrderedDict, defaultdict @@ -158,11 +160,13 @@ def build_basic_type(obj): return dict(OPENAPI_TYPE_MAPPING[OpenApiTypes.STR]) -def build_array_type(schema): - return { - 'type': 'array', - 'items': schema, - } +def build_array_type(schema, min_length=None, max_length=None): + schema = {'type': 'array', 'items': schema} + if min_length is not None: + schema['minLength'] = min_length + if max_length is not None: + schema['maxLength'] = max_length + return schema def build_object_type( @@ -731,3 +735,60 @@ def set_query_parameters(url, **kwargs) -> str: def get_relative_url(url: str) -> str: scheme, netloc, path, params, query, fragment = urllib.parse.urlparse(url) return urllib.parse.urlunparse(('', '', path, params, query, fragment)) + + +def _get_type_hint_origin(hint): + """ graceful fallback for py 3.8 typing functionality """ + if sys.version_info >= (3, 8): + return typing.get_origin(hint), typing.get_args(hint) + else: + origin = getattr(hint, '__origin__', None) + args = getattr(hint, '__args__', None) + origin = { + typing.List: list, + typing.Dict: dict, + typing.Tuple: tuple, + typing.Set: set, + typing.FrozenSet: frozenset + }.get(origin, origin) + return origin, args + + +def resolve_type_hint(hint): + """ resolve return value type hints to schema """ + origin, args = _get_type_hint_origin(hint) + + if origin is None and is_basic_type(hint, allow_none=False): + return build_basic_type(hint) + elif origin is list or hint is list: + return build_array_type(build_basic_type(args[0] if args else OpenApiTypes.OBJECT)) + elif origin is tuple: + return build_array_type( + schema=build_basic_type(args[0]), + max_length=len(args), + min_length=len(args), + ) + elif origin is dict or origin is defaultdict or origin is OrderedDict: + schema = build_basic_type(OpenApiTypes.OBJECT) + if args[1] is not typing.Any: + schema['additionalProperties'] = resolve_type_hint(args[1]) + return schema + elif origin is set: + return build_array_type(resolve_type_hint(args[0])) + elif origin is frozenset: + return build_array_type(resolve_type_hint(args[0])) + elif hasattr(typing, 'Literal') and origin is typing.Literal: + # python >= 3.8 + schema = {'enum': list(args)} + if all(type(args[0]) is type(choice) for choice in args): + schema.update(build_basic_type(type(args[0]))) + return schema + elif origin is typing.Union and len(args) == 2 and isinstance(None, args[1]): + # Optional[*] is resolved to Union[*, None] + schema = resolve_type_hint(args[0]) + schema['nullable'] = True + return schema + elif origin is typing.Union: + return {'oneOf': [resolve_type_hint(arg) for arg in args]} + else: + raise UnableToProceedError() diff --git a/tests/test_plumbing.py b/tests/test_plumbing.py index cc72f000..fbf08704 100644 --- a/tests/test_plumbing.py +++ b/tests/test_plumbing.py @@ -1,12 +1,20 @@ +import json +import sys +import typing +from datetime import datetime + +import pytest from django.conf.urls import include from django.db import models from django.urls import re_path -from rest_framework import serializers +from rest_framework import generics, serializers from drf_spectacular.openapi import AutoSchema from drf_spectacular.plumbing import ( - detype_pattern, follow_field_source, force_instance, is_field, is_serializer, + detype_pattern, follow_field_source, force_instance, is_field, is_serializer, resolve_type_hint, ) +from drf_spectacular.validation import validate_schema +from tests import generate_schema def test_is_serializer(): @@ -69,3 +77,67 @@ def test_detype_patterns_with_module_includes(no_warnings): detype_pattern( pattern=re_path(r'^', include('tests.test_fields')) ) + + +TYPE_HINT_TEST_PARAMS = [ + ( + typing.Optional[int], + {'type': 'integer', 'nullable': True} + ), ( + typing.List[int], + {'type': 'array', 'items': {'type': 'integer'}} + ), ( + list, + {'type': 'array', 'items': {'type': 'object', 'additionalProperties': {}}} + ), ( + typing.Tuple[int, int, int], + {'type': 'array', 'items': {'type': 'integer'}, 'minLength': 3, 'maxLength': 3} + ), ( + typing.Set[datetime], + {'type': 'array', 'items': {'type': 'string', 'format': 'date-time'}} + ), ( + typing.FrozenSet[datetime], + {'type': 'array', 'items': {'type': 'string', 'format': 'date-time'}} + ), ( + typing.Dict[str, int], + {'type': 'object', 'additionalProperties': {'type': 'integer'}} + ), ( + typing.Dict[str, str], + {'type': 'object', 'additionalProperties': {'type': 'string'}} + ), ( + typing.Dict[str, typing.List[int]], + {'type': 'object', 'additionalProperties': {'type': 'array', 'items': {'type': 'integer'}}} + ), ( + dict, + {'type': 'object', 'additionalProperties': {}} + ), ( + typing.Union[int, str], + {'oneOf': [{'type': 'integer'}, {'type': 'string'}]} + ) +] + +if sys.version_info >= (3, 8): + TYPE_HINT_TEST_PARAMS.append(( + typing.Literal['x', 'y'], + {'enum': ['x', 'y'], 'type': 'string'} + )) + + +@pytest.mark.parametrize(['type_hint', 'ref_schema'], TYPE_HINT_TEST_PARAMS) +def test_type_hint_extraction(no_warnings, type_hint, ref_schema): + def func() -> type_hint: + pass # pragma: no cover + + # check expected resolution + schema = resolve_type_hint(typing.get_type_hints(func).get('return')) + assert json.dumps(schema) == json.dumps(ref_schema) + + # check schema validity + class XSerializer(serializers.Serializer): + x = serializers.SerializerMethodField() + XSerializer.get_x = func + + class XView(generics.RetrieveAPIView): + serializer_class = XSerializer + + validate_schema(generate_schema('/x', view=XView)) diff --git a/tests/test_warnings.py b/tests/test_warnings.py index 8e66907d..de23b012 100644 --- a/tests/test_warnings.py +++ b/tests/test_warnings.py @@ -203,8 +203,8 @@ def get(self, request): generate_schema('x', view=XAPIView) stderr = capsys.readouterr().err - assert 'type hint for function "x" is unknown.' in stderr - assert 'type hint for function "get_y" is unknown.' in stderr + assert 'unable to resolve type hint for function "x"' in stderr + assert 'unable to resolve type hint for function "get_y"' in stderr def test_operation_id_collision_resolution(capsys):