From 40196b50ae802a874e472763b2190440a5086ea5 Mon Sep 17 00:00:00 2001 From: Dave Shawley Date: Sat, 23 Mar 2024 14:29:59 -0400 Subject: [PATCH] Add UUID support Note that `typing.Annotated[uuid.UUID, ...]` doesn't work currently so I created a helper class `routing._UUID` that works around the defect. This will be fixed in 3.12.3 when it is released. I added an ignore statement for `ruff check` to allow access to internal names inside of tests. Otherwise testing `_UUID` required too many `noqa` comment directives for me. https://github.com/python/cpython/issues/115165 --- pyproject.toml | 3 +++ src/pydantictornado/routing.py | 26 +++++++++++++++++++++++++- tests/test_openapi.py | 10 +++++++--- tests/test_routing.py | 27 ++++++++++++++++++++++++++- tests/test_util.py | 16 ++++++---------- 5 files changed, 67 insertions(+), 15 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index fc6c204..b39eab1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -95,3 +95,6 @@ ignore = [ ] pycodestyle = {ignore-overlong-task-comments = true} select = ["ALL"] + +[tool.ruff.lint.per-file-ignores] +"tests/*.py" = ["SLF001"] diff --git a/src/pydantictornado/routing.py b/src/pydantictornado/routing.py index 6d7cf1c..df1a958 100644 --- a/src/pydantictornado/routing.py +++ b/src/pydantictornado/routing.py @@ -46,6 +46,27 @@ class ParameterAnnotation(pydantic.BaseModel): explode: bool | None = None +class _UUID(uuid.UUID): + """Wrapper class to work around defect in uuid.UUID + + This works around a defect in annotation process of + immutable values that will be fixed in 3.12.3. + + """ + + def __init__(self, *args: object, **kwargs: object) -> None: + if len(args) == 1 and isinstance(args[0], uuid.UUID): + super().__init__(int=args[0].int) + else: + super().__init__(*args, **kwargs) # type: ignore[arg-type] + + def __setattr__(self, key: str, value: object) -> None: + # See https://github.com/python/cpython/issues/115165 + if key == '__orig_class__': + return + super().__setattr__(key, value) + + def _initialize_converters( m: collections.abc.MutableMapping[type, _PathConverter] ) -> None: @@ -64,7 +85,10 @@ def _initialize_converters( str: typing.Annotated[ lambda s: s, ParameterAnnotation(schema_={'type': 'string'}) ], - uuid.UUID: uuid.UUID, + uuid.UUID: typing.Annotated[ + _UUID, + ParameterAnnotation(schema_={'type': 'string', 'format': 'uuid'}), + ], datetime.date: typing.Annotated[ util.parse_date, ParameterAnnotation(schema_={'type': 'string', 'format': 'date'}), diff --git a/tests/test_openapi.py b/tests/test_openapi.py index 0c915a8..994e068 100644 --- a/tests/test_openapi.py +++ b/tests/test_openapi.py @@ -272,7 +272,7 @@ def test_extras_at_multiple_levels(self) -> None: class OpenAPIRegexTests(unittest.TestCase): @staticmethod def translate_path_pattern(pattern: str) -> openapi.OpenAPIPath: - return openapi._translate_path_pattern(re.compile(pattern)) # noqa: SLF001 + return openapi._translate_path_pattern(re.compile(pattern)) def test_simple_paths(self) -> None: result = self.translate_path_pattern(r'/items/(?P.*)') @@ -412,7 +412,11 @@ async def status() -> dict[str, str]: return {} 'in': 'path', 'required': True, 'deprecated': False, - 'schema': {'type': 'string', 'pattern': '.*'}, + 'schema': { + 'type': 'string', + 'pattern': '.*', + 'format': 'uuid', + }, }, description['paths']['/items/{_id}']['parameters'][0], ) @@ -474,7 +478,7 @@ def test_mixed_annotations(self) -> None: async def op(item_id: IdType) -> IdType: return item_id - description = openapi._describe_path( # noqa: SLF001 private use ok + description = openapi._describe_path( routing.Route(r'/items/(?P\d+)', get=op), openapi.OpenAPIPath( path='/items/{item_id}', patterns={'item_id': r'\d+'} diff --git a/tests/test_routing.py b/tests/test_routing.py index fe72f36..a819412 100644 --- a/tests/test_routing.py +++ b/tests/test_routing.py @@ -2,6 +2,7 @@ import datetime import ipaddress import re +import typing import unittest.mock import uuid @@ -95,7 +96,7 @@ async def impl(*, _obj: cls) -> None: # type: ignore[valid-type] r = routing.Route(r'/(?P<_obj>.*)', get=impl) result = r.target_kwargs['path_types']['_obj'](str_value) self.assertTrue( - issubclass(cls, type(result)), + issubclass(type(result), cls), f'parsing {str_value!r} produced incompatible' f' type {type(result)}', ) @@ -209,3 +210,27 @@ async def str_impl(*, _id: str) -> None: # should not raise routing.Route(r'/(?P<_id>.*)', get=int_impl, delete=another_int_impl) + + +class UUIDWrapperTests(unittest.TestCase): + def test_that_construction_alternatives_are_supported(self) -> None: + value = uuid.uuid4() + for alternative in ('hex', 'bytes', 'bytes_le', 'fields', 'int'): + self.assertEqual( + value, + routing._UUID(**{alternative: getattr(value, alternative)}), + ) + + def test_that_uuid_copying_is_supported(self) -> None: + value = uuid.uuid4() + self.assertEqual(value, routing._UUID(value)) + + def test_that_uuid_can_be_annotated(self) -> None: + uuid_type = typing.Annotated[routing._UUID, 'whatever'] + self.assertEqual(uuid.UUID(int=0), uuid_type(int=0)) + + def test_that_uuid_is_still_immutable(self) -> None: + uuid_type = typing.Annotated[routing._UUID, 'whatever'] + uuid = uuid_type(int=0) + with self.assertRaises(TypeError): + uuid.version = 4 # type: ignore[misc] diff --git a/tests/test_util.py b/tests/test_util.py index 90d3cd4..f790f73 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -75,16 +75,14 @@ def test_that_non_types_fail(self) -> None: def test_that_clear_resets_cache(self) -> None: mapping = util.ClassMapping[str]() - mapping_cache = mapping._cache # noqa: SLF001 -- private access - mapping[int] = 'int' mapping[bool] = 'bool' mapping[float] = 'float' mapping.populate_cache() - self.assertEqual(len(mapping_cache), 3) + self.assertEqual(len(mapping._cache), 3) mapping.clear() - self.assertEqual(len(mapping_cache), 0) + self.assertEqual(len(mapping._cache), 0) def test_rebuild(self) -> None: def initializer(m: collections.abc.MutableMapping[type, str]) -> None: @@ -98,22 +96,20 @@ def initializer(m: collections.abc.MutableMapping[type, str]) -> None: self.assertEqual(len(mapping), 0) mapping = util.ClassMapping[str](initialize_data=initializer) - mapping_cache = mapping._cache # noqa: SLF001 -- private access - self.assertEqual(len(mapping), 3) - self.assertEqual(len(mapping_cache), 3) + self.assertEqual(len(mapping._cache), 3) mapping.clear() self.assertEqual(len(mapping), 0) - self.assertEqual(len(mapping_cache), 0) + self.assertEqual(len(mapping._cache), 0) mapping[float] = 'float' self.assertEqual(len(mapping), 1) - self.assertEqual(len(mapping_cache), 1) + self.assertEqual(len(mapping._cache), 1) mapping.rebuild() self.assertEqual(len(mapping), 3) - self.assertEqual(len(mapping_cache), 3) + self.assertEqual(len(mapping._cache), 3) class JSONSerializationTests(unittest.TestCase):