From 0a3b6c4048c58e248413b894790b61cc51cbdc74 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Wed, 10 Aug 2022 18:25:58 +0100 Subject: [PATCH 01/11] WIP --- scripts-dev/check_pydantic_models.py | 151 +++++++++++++++++++++++++++ scripts-dev/lint.sh | 1 + 2 files changed, 152 insertions(+) create mode 100755 scripts-dev/check_pydantic_models.py diff --git a/scripts-dev/check_pydantic_models.py b/scripts-dev/check_pydantic_models.py new file mode 100755 index 000000000000..4539db5d190f --- /dev/null +++ b/scripts-dev/check_pydantic_models.py @@ -0,0 +1,151 @@ +#! /usr/bin/env python +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +import contextlib +import functools +import sys +import textwrap +import unittest.mock +from contextlib import contextmanager +from typing import Generator, Any + +from pydantic import confloat, conint, conbytes, constr + +CONSTRAINED_TYPE_FACTORIES_WITH_STRICT_FLAG = [ + constr, + conbytes, + conint, + confloat, +] + + +@contextmanager +def monkeypatch_pydantic() -> Generator[None, None, None]: + with contextlib.ExitStack() as patches: + for factory in CONSTRAINED_TYPE_FACTORIES_WITH_STRICT_FLAG: + + @functools.wraps(factory) + def wrapper(**kwargs: object) -> Any: + assert "strict" in kwargs + assert kwargs["strict"] + return factory(**kwargs) + + patch1 = unittest.mock.patch(f"pydantic.{factory.__name__}", new=wrapper) + patch2 = unittest.mock.patch( + f"pydantic.types.{factory.__name__}", new=wrapper + ) + patches.enter_context(patch1) + patches.enter_context(patch2) + yield + + +def run_test_snippet(source: str) -> None: + exec(textwrap.dedent(source), {}, {}) + + +class TestConstrainedTypesPatch(unittest.TestCase): + def test_expression_without_strict_raises(self): + with monkeypatch_pydantic(), self.assertRaises(Exception): + run_test_snippet( + """ + from pydantic import constr + constr() + """ + ) + + def test_called_as_module_attribute_raises(self): + with monkeypatch_pydantic(), self.assertRaises(Exception): + run_test_snippet( + """ + import pydantic + pydantic.constr() + """ + ) + + def test_alternative_import_raises(self): + with monkeypatch_pydantic(), self.assertRaises(Exception): + run_test_snippet( + """ + from pydantic.types import constr + constr() + """ + ) + + def test_alternative_import_attribute_raises(self): + with monkeypatch_pydantic(), self.assertRaises(Exception): + run_test_snippet( + """ + import pydantic.types + pydantic.types.constr() + """ + ) + + def test_kwarg_but_no_strict_raises(self): + with monkeypatch_pydantic(), self.assertRaises(Exception): + run_test_snippet( + """ + from pydantic import constr + constr(min_length=10) + """ + ) + + def test_kwarg_strict_False_raises(self): + with monkeypatch_pydantic(), self.assertRaises(Exception): + run_test_snippet( + """ + from pydantic import constr + constr(strict=False) + """ + ) + + def test_kwarg_strict_True_doesnt_raise(self): + with monkeypatch_pydantic(): + run_test_snippet( + """ + from pydantic import constr + constr(strict=True) + """ + ) + + def test_annotation_without_strict_raises(self): + with monkeypatch_pydantic(), self.assertRaises(Exception): + run_test_snippet( + """ + from pydantic import constr + x: constr() + """ + ) + + def test_field_annotation_without_strict_raises(self): + with monkeypatch_pydantic(), self.assertRaises(Exception): + run_test_snippet( + """ + from pydantic import BaseModel, constr + class C(BaseModel): + f: constr() + """ + ) + + +parser = argparse.ArgumentParser() +parser.add_argument("mode", choices=["lint", "test"]) + + +if __name__ == "__main__": + args = parser.parse_args(sys.argv[1:]) + if args.mode == "lint": + ... + elif args.mode == "test": + unittest.main(argv=sys.argv[:1]) diff --git a/scripts-dev/lint.sh b/scripts-dev/lint.sh index 377348b107ea..68a37bcefbc6 100755 --- a/scripts-dev/lint.sh +++ b/scripts-dev/lint.sh @@ -106,4 +106,5 @@ isort "${files[@]}" python3 -m black "${files[@]}" ./scripts-dev/config-lint.sh flake8 "${files[@]}" +./scripts-dev/check-pydantic-models.py mypy From 107a08f30356a8f80d30a1b55672e658eb6b05a6 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Wed, 10 Aug 2022 19:42:57 +0100 Subject: [PATCH 02/11] wip 2 --- scripts-dev/check_pydantic_models.py | 132 ++++++++++++++++++++++----- 1 file changed, 111 insertions(+), 21 deletions(-) diff --git a/scripts-dev/check_pydantic_models.py b/scripts-dev/check_pydantic_models.py index 4539db5d190f..86cfc3c1b002 100755 --- a/scripts-dev/check_pydantic_models.py +++ b/scripts-dev/check_pydantic_models.py @@ -12,16 +12,35 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +""" +A script which enforces that Synapse always uses strict types when defining a Pydantic +model. + +Pydantic does not yet offer a strict mode (), but it is expected for V2. See + https://github.com/pydantic/pydantic/issues/1098 + https://pydantic-docs.helpmanual.io/blog/pydantic-v2/#strict-mode + +until then, this script stops us from introducing type coersion bugs like stringy power +levels. +""" import argparse import contextlib import functools +import importlib +import logging +import os +import pkgutil import sys import textwrap +import traceback import unittest.mock from contextlib import contextmanager -from typing import Generator, Any +from typing import Generator, TypeVar, Callable, Set from pydantic import confloat, conint, conbytes, constr +from typing_extensions import ParamSpec + +logger = logging.getLogger(__name__) CONSTRAINED_TYPE_FACTORIES_WITH_STRICT_FLAG = [ constr, @@ -31,17 +50,31 @@ ] +P = ParamSpec("P") +R = TypeVar("R") + + +class NonStrictTypeError(Exception): + ... + + +def make_wrapper(factory: Callable[P, R]) -> Callable[P, R]: + @functools.wraps(factory) + def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: + if "strict" not in kwargs: + raise NonStrictTypeError() + if not kwargs["strict"]: + raise NonStrictTypeError() + return factory(*args, **kwargs) + + return wrapper + + @contextmanager def monkeypatch_pydantic() -> Generator[None, None, None]: with contextlib.ExitStack() as patches: for factory in CONSTRAINED_TYPE_FACTORIES_WITH_STRICT_FLAG: - - @functools.wraps(factory) - def wrapper(**kwargs: object) -> Any: - assert "strict" in kwargs - assert kwargs["strict"] - return factory(**kwargs) - + wrapper = make_wrapper(factory) patch1 = unittest.mock.patch(f"pydantic.{factory.__name__}", new=wrapper) patch2 = unittest.mock.patch( f"pydantic.types.{factory.__name__}", new=wrapper @@ -51,13 +84,63 @@ def wrapper(**kwargs: object) -> Any: yield +def format_error(e: Exception) -> str: + frame_summary = traceback.extract_tb(e.__traceback__)[-2] + return traceback.format_list([frame_summary])[0].lstrip() + + +def lint() -> int: + failures = do_lint() + if failures: + print(f"Found {len(failures)} problem(s)") + for failure in sorted(failures): + print(failure) + return os.EX_DATAERR if failures else os.EX_OK + + +def do_lint() -> Set[str]: + failures = set() + + with monkeypatch_pydantic(): + try: + synapse = importlib.import_module("synapse") + except NonStrictTypeError as e: + logger.warning(f"Bad annotation from importing synapse") + failures.add(format_error(e)) + return failures + + try: + modules = list(pkgutil.walk_packages(synapse.__path__, "synapse.")) + except NonStrictTypeError as e: + logger.warning(f"Bad annotation when looking for modules to import") + failures.add(format_error(e)) + return failures + + for module in modules: + logger.debug("Importing %s", module.name) + try: + importlib.import_module(module.name) + except NonStrictTypeError as e: + logger.warning(f"Bad annotation from importing {module.name}") + failures.add(format_error(e)) + + return failures + + def run_test_snippet(source: str) -> None: - exec(textwrap.dedent(source), {}, {}) + # To emulate `source` being called at the top level of the module, + # the globals and locals we provide have to be the same mapping. + # + # > Remember that at the module level, globals and locals are the same dictionary. + # > If exec gets two separate objects as globals and locals, the code will be + # > executed as if it were embedded in a class definition. + g = l = {} + exec(textwrap.dedent(source), g, l) class TestConstrainedTypesPatch(unittest.TestCase): def test_expression_without_strict_raises(self): - with monkeypatch_pydantic(), self.assertRaises(Exception): + with monkeypatch_pydantic(), self.assertRaises(NonStrictTypeError): run_test_snippet( """ from pydantic import constr @@ -66,7 +149,7 @@ def test_expression_without_strict_raises(self): ) def test_called_as_module_attribute_raises(self): - with monkeypatch_pydantic(), self.assertRaises(Exception): + with monkeypatch_pydantic(), self.assertRaises(NonStrictTypeError): run_test_snippet( """ import pydantic @@ -75,7 +158,7 @@ def test_called_as_module_attribute_raises(self): ) def test_alternative_import_raises(self): - with monkeypatch_pydantic(), self.assertRaises(Exception): + with monkeypatch_pydantic(), self.assertRaises(NonStrictTypeError): run_test_snippet( """ from pydantic.types import constr @@ -84,7 +167,7 @@ def test_alternative_import_raises(self): ) def test_alternative_import_attribute_raises(self): - with monkeypatch_pydantic(), self.assertRaises(Exception): + with monkeypatch_pydantic(), self.assertRaises(NonStrictTypeError): run_test_snippet( """ import pydantic.types @@ -93,7 +176,7 @@ def test_alternative_import_attribute_raises(self): ) def test_kwarg_but_no_strict_raises(self): - with monkeypatch_pydantic(), self.assertRaises(Exception): + with monkeypatch_pydantic(), self.assertRaises(NonStrictTypeError): run_test_snippet( """ from pydantic import constr @@ -102,7 +185,7 @@ def test_kwarg_but_no_strict_raises(self): ) def test_kwarg_strict_False_raises(self): - with monkeypatch_pydantic(), self.assertRaises(Exception): + with monkeypatch_pydantic(), self.assertRaises(NonStrictTypeError): run_test_snippet( """ from pydantic import constr @@ -120,7 +203,7 @@ def test_kwarg_strict_True_doesnt_raise(self): ) def test_annotation_without_strict_raises(self): - with monkeypatch_pydantic(), self.assertRaises(Exception): + with monkeypatch_pydantic(), self.assertRaises(NonStrictTypeError): run_test_snippet( """ from pydantic import constr @@ -129,23 +212,30 @@ def test_annotation_without_strict_raises(self): ) def test_field_annotation_without_strict_raises(self): - with monkeypatch_pydantic(), self.assertRaises(Exception): + with monkeypatch_pydantic(), self.assertRaises(NonStrictTypeError): run_test_snippet( """ - from pydantic import BaseModel, constr - class C(BaseModel): - f: constr() + from pydantic import BaseModel, conint + class C: + x: conint() """ ) parser = argparse.ArgumentParser() parser.add_argument("mode", choices=["lint", "test"]) +parser.add_argument("-v", "--verbose", action="store_true") if __name__ == "__main__": args = parser.parse_args(sys.argv[1:]) + logging.basicConfig( + format="%(asctime)s %(name)s:%(lineno)d %(levelname)s %(message)s", + level=logging.DEBUG if args.verbose else logging.INFO, + ) + # suppress logs we don't care about + logging.getLogger("xmlschema").setLevel(logging.WARNING) if args.mode == "lint": - ... + sys.exit(lint()) elif args.mode == "test": unittest.main(argv=sys.argv[:1]) From aa6396efb0b922f6c02821f96630cec9fec58ec1 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Thu, 11 Aug 2022 00:01:32 +0100 Subject: [PATCH 03/11] Mess with linter script --- .github/workflows/tests.yml | 15 ++++++++++++++- scripts-dev/lint.sh | 2 +- 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 4bc29c820759..92dc58564eef 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -53,10 +53,23 @@ jobs: env: PULL_REQUEST_NUMBER: ${{ github.event.number }} + lint-pydantic: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + with: + ref: ${{ github.event.pull_request.head.sha }} + fetch-depth: 0 + - uses: matrix-org/setup-python-poetry@v1 + with: + python-version: ${{ matrix.python-version }} + extras: ${{ matrix.extras }} + - run: poetry run scripts-dev/check_pydantic_models.py + # Dummy step to gate other tests on without repeating the whole list linting-done: if: ${{ !cancelled() }} # Run this even if prior jobs were skipped - needs: [lint, lint-crlf, lint-newsfile, check-sampleconfig, check-schema-delta] + needs: [lint, lint-crlf, lint-newsfile, lint-pydantic, check-sampleconfig, check-schema-delta] runs-on: ubuntu-latest steps: - run: "true" diff --git a/scripts-dev/lint.sh b/scripts-dev/lint.sh index 68a37bcefbc6..bf900645b1f7 100755 --- a/scripts-dev/lint.sh +++ b/scripts-dev/lint.sh @@ -106,5 +106,5 @@ isort "${files[@]}" python3 -m black "${files[@]}" ./scripts-dev/config-lint.sh flake8 "${files[@]}" -./scripts-dev/check-pydantic-models.py +./scripts-dev/check_pydantic_models.py lint mypy From 5fdc607c2804f8f5e3c7dafcccc4d74fd12dcdc0 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Thu, 11 Aug 2022 00:01:45 +0100 Subject: [PATCH 04/11] WIP check for the easy unintentional coercions --- scripts-dev/check_pydantic_models.py | 88 ++++++++++++++++++++++------ 1 file changed, 71 insertions(+), 17 deletions(-) diff --git a/scripts-dev/check_pydantic_models.py b/scripts-dev/check_pydantic_models.py index 86cfc3c1b002..e48d1f5f602a 100755 --- a/scripts-dev/check_pydantic_models.py +++ b/scripts-dev/check_pydantic_models.py @@ -35,9 +35,10 @@ import traceback import unittest.mock from contextlib import contextmanager -from typing import Generator, TypeVar, Callable, Set +from typing import Callable, Generator, Set, Type, TypeVar -from pydantic import confloat, conint, conbytes, constr +from parameterized import parameterized +from pydantic import BaseModel as PydanticBaseModel, conbytes, confloat, conint, constr from typing_extensions import ParamSpec logger = logging.getLogger(__name__) @@ -49,6 +50,14 @@ confloat, ] +TYPES_THAT_PYDANTIC_WILL_COERCE_TO = [ + str, + bytes, + int, + float, + bool, +] + P = ParamSpec("P") R = TypeVar("R") @@ -70,9 +79,23 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: return wrapper +class BaseModel(PydanticBaseModel): + @classmethod + def __init_subclass__(cls: Type[PydanticBaseModel], **kwargs): + for field in cls.__fields__.values(): + if field.type_ in TYPES_THAT_PYDANTIC_WILL_COERCE_TO: + raise NonStrictTypeError() + # breakpoint() + # print(cls, kwargs) + + @contextmanager def monkeypatch_pydantic() -> Generator[None, None, None]: with contextlib.ExitStack() as patches: + patch_basemodel1 = unittest.mock.patch("pydantic.BaseModel", new=BaseModel) + patch_basemodel2 = unittest.mock.patch("pydantic.main.BaseModel", new=BaseModel) + patches.enter_context(patch_basemodel1) + patches.enter_context(patch_basemodel2) for factory in CONSTRAINED_TYPE_FACTORIES_WITH_STRICT_FLAG: wrapper = make_wrapper(factory) patch1 = unittest.mock.patch(f"pydantic.{factory.__name__}", new=wrapper) @@ -105,14 +128,14 @@ def do_lint() -> Set[str]: try: synapse = importlib.import_module("synapse") except NonStrictTypeError as e: - logger.warning(f"Bad annotation from importing synapse") + logger.warning("Bad annotation found when importing synapse") failures.add(format_error(e)) return failures try: modules = list(pkgutil.walk_packages(synapse.__path__, "synapse.")) except NonStrictTypeError as e: - logger.warning(f"Bad annotation when looking for modules to import") + logger.warning("Bad annotation found when looking for modules to import") failures.add(format_error(e)) return failures @@ -121,7 +144,7 @@ def do_lint() -> Set[str]: try: importlib.import_module(module.name) except NonStrictTypeError as e: - logger.warning(f"Bad annotation from importing {module.name}") + logger.warning(f"Bad annotation found when importing {module.name}") failures.add(format_error(e)) return failures @@ -129,17 +152,17 @@ def do_lint() -> Set[str]: def run_test_snippet(source: str) -> None: # To emulate `source` being called at the top level of the module, - # the globals and locals we provide have to be the same mapping. + # the globals and locals we provide apparently have to be the same mapping. # # > Remember that at the module level, globals and locals are the same dictionary. # > If exec gets two separate objects as globals and locals, the code will be # > executed as if it were embedded in a class definition. - g = l = {} - exec(textwrap.dedent(source), g, l) + globals_ = locals_ = {} + exec(textwrap.dedent(source), globals_, locals_) class TestConstrainedTypesPatch(unittest.TestCase): - def test_expression_without_strict_raises(self): + def test_expression_without_strict_raises(self) -> None: with monkeypatch_pydantic(), self.assertRaises(NonStrictTypeError): run_test_snippet( """ @@ -148,7 +171,7 @@ def test_expression_without_strict_raises(self): """ ) - def test_called_as_module_attribute_raises(self): + def test_called_as_module_attribute_raises(self) -> None: with monkeypatch_pydantic(), self.assertRaises(NonStrictTypeError): run_test_snippet( """ @@ -157,7 +180,7 @@ def test_called_as_module_attribute_raises(self): """ ) - def test_alternative_import_raises(self): + def test_alternative_import_raises(self) -> None: with monkeypatch_pydantic(), self.assertRaises(NonStrictTypeError): run_test_snippet( """ @@ -166,7 +189,7 @@ def test_alternative_import_raises(self): """ ) - def test_alternative_import_attribute_raises(self): + def test_alternative_import_attribute_raises(self) -> None: with monkeypatch_pydantic(), self.assertRaises(NonStrictTypeError): run_test_snippet( """ @@ -175,7 +198,7 @@ def test_alternative_import_attribute_raises(self): """ ) - def test_kwarg_but_no_strict_raises(self): + def test_kwarg_but_no_strict_raises(self) -> None: with monkeypatch_pydantic(), self.assertRaises(NonStrictTypeError): run_test_snippet( """ @@ -184,7 +207,7 @@ def test_kwarg_but_no_strict_raises(self): """ ) - def test_kwarg_strict_False_raises(self): + def test_kwarg_strict_False_raises(self) -> None: with monkeypatch_pydantic(), self.assertRaises(NonStrictTypeError): run_test_snippet( """ @@ -193,7 +216,7 @@ def test_kwarg_strict_False_raises(self): """ ) - def test_kwarg_strict_True_doesnt_raise(self): + def test_kwarg_strict_True_doesnt_raise(self) -> None: with monkeypatch_pydantic(): run_test_snippet( """ @@ -202,7 +225,7 @@ def test_kwarg_strict_True_doesnt_raise(self): """ ) - def test_annotation_without_strict_raises(self): + def test_annotation_without_strict_raises(self) -> None: with monkeypatch_pydantic(), self.assertRaises(NonStrictTypeError): run_test_snippet( """ @@ -211,7 +234,7 @@ def test_annotation_without_strict_raises(self): """ ) - def test_field_annotation_without_strict_raises(self): + def test_field_annotation_without_strict_raises(self) -> None: with monkeypatch_pydantic(), self.assertRaises(NonStrictTypeError): run_test_snippet( """ @@ -222,6 +245,37 @@ class C: ) +class TestMetaclassPatch(unittest.TestCase): + @parameterized.expand( + [ + ("str",), + ("bytes"), + ("int",), + ("float",), + ("bool"), + ] + ) + def test_field_holding_plain_value_type_raises(self, type_name: str) -> None: + with monkeypatch_pydantic(), self.assertRaises(NonStrictTypeError): + run_test_snippet( + f""" + from pydantic import BaseModel + class C(BaseModel): + f: {type_name} + """ + ) + + def test_field_holding_str_raises_with_alternative_import(self) -> None: + with monkeypatch_pydantic(), self.assertRaises(NonStrictTypeError): + run_test_snippet( + """ + from pydantic.main import BaseModel + class C(BaseModel): + f: str + """ + ) + + parser = argparse.ArgumentParser() parser.add_argument("mode", choices=["lint", "test"]) parser.add_argument("-v", "--verbose", action="store_true") From 07b6d75d018daf1d997161559ea0fd2fe84512e9 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Thu, 11 Aug 2022 01:38:58 +0100 Subject: [PATCH 05/11] Recursive type inspection + tidy up --- scripts-dev/check_pydantic_models.py | 109 ++++++++++++++++++++++----- 1 file changed, 89 insertions(+), 20 deletions(-) diff --git a/scripts-dev/check_pydantic_models.py b/scripts-dev/check_pydantic_models.py index e48d1f5f602a..4fa601a4c907 100755 --- a/scripts-dev/check_pydantic_models.py +++ b/scripts-dev/check_pydantic_models.py @@ -16,12 +16,13 @@ A script which enforces that Synapse always uses strict types when defining a Pydantic model. -Pydantic does not yet offer a strict mode (), but it is expected for V2. See +Pydantic does not yet offer a strict mode, but it is planned for pydantic v2. See + https://github.com/pydantic/pydantic/issues/1098 https://pydantic-docs.helpmanual.io/blog/pydantic-v2/#strict-mode -until then, this script stops us from introducing type coersion bugs like stringy power -levels. +until then, this script is a best effort to stop us from introducing type coersion bugs +(like the infamous stringy power levels fixed in room version 10). """ import argparse import contextlib @@ -35,15 +36,16 @@ import traceback import unittest.mock from contextlib import contextmanager -from typing import Callable, Generator, Set, Type, TypeVar +from typing import Any, Callable, Generator, Set, Type, TypeVar, List, Dict from parameterized import parameterized from pydantic import BaseModel as PydanticBaseModel, conbytes, confloat, conint, constr +from pydantic.typing import get_args from typing_extensions import ParamSpec logger = logging.getLogger(__name__) -CONSTRAINED_TYPE_FACTORIES_WITH_STRICT_FLAG = [ +CONSTRAINED_TYPE_FACTORIES_WITH_STRICT_FLAG: List[Callable] = [ constr, conbytes, conint, @@ -70,34 +72,55 @@ class NonStrictTypeError(Exception): def make_wrapper(factory: Callable[P, R]) -> Callable[P, R]: @functools.wraps(factory) def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: - if "strict" not in kwargs: + # type-ignore: should be redundant once we can use https://github.com/python/mypy/pull/12668 + if "strict" not in kwargs: # type: ignore[attr-defined] raise NonStrictTypeError() - if not kwargs["strict"]: + if not kwargs["strict"]: # type: ignore[index] raise NonStrictTypeError() return factory(*args, **kwargs) return wrapper -class BaseModel(PydanticBaseModel): +def field_type_unwanted(type_: Any) -> bool: + logger.debug("Is %s unwanted?") + if type_ in TYPES_THAT_PYDANTIC_WILL_COERCE_TO: + logger.debug("yes") + return True + logger.debug("Maybe. Subargs are %s", get_args(type_)) + rv = any(field_type_unwanted(t) for t in get_args(type_)) + logger.debug("Conclusion: %s %s unwanted", type_, "is" if rv else "is not") + return rv + + +class PatchedBaseModel(PydanticBaseModel): + """Try to detect fields whose + + ModelField.type_ is presumably private, so this is likely to be very brittle. + """ + @classmethod - def __init_subclass__(cls: Type[PydanticBaseModel], **kwargs): + def __init_subclass__(cls: Type[PydanticBaseModel], **kwargs: object): for field in cls.__fields__.values(): - if field.type_ in TYPES_THAT_PYDANTIC_WILL_COERCE_TO: + # Note that field.type_ and field.outer_type are computed based on the + # annotation type, see pydantic.fields.ModelField._type_analysis + if field_type_unwanted(field.outer_type_): raise NonStrictTypeError() - # breakpoint() - # print(cls, kwargs) @contextmanager def monkeypatch_pydantic() -> Generator[None, None, None]: with contextlib.ExitStack() as patches: - patch_basemodel1 = unittest.mock.patch("pydantic.BaseModel", new=BaseModel) - patch_basemodel2 = unittest.mock.patch("pydantic.main.BaseModel", new=BaseModel) + patch_basemodel1 = unittest.mock.patch( + "pydantic.BaseModel", new=PatchedBaseModel + ) + patch_basemodel2 = unittest.mock.patch( + "pydantic.main.BaseModel", new=PatchedBaseModel + ) patches.enter_context(patch_basemodel1) patches.enter_context(patch_basemodel2) for factory in CONSTRAINED_TYPE_FACTORIES_WITH_STRICT_FLAG: - wrapper = make_wrapper(factory) + wrapper: Callable = make_wrapper(factory) patch1 = unittest.mock.patch(f"pydantic.{factory.__name__}", new=wrapper) patch2 = unittest.mock.patch( f"pydantic.types.{factory.__name__}", new=wrapper @@ -157,6 +180,8 @@ def run_test_snippet(source: str) -> None: # > Remember that at the module level, globals and locals are the same dictionary. # > If exec gets two separate objects as globals and locals, the code will be # > executed as if it were embedded in a class definition. + globals_: Dict[str, object] + locals_: Dict[str, object] globals_ = locals_ = {} exec(textwrap.dedent(source), globals_, locals_) @@ -180,6 +205,15 @@ def test_called_as_module_attribute_raises(self) -> None: """ ) + def test_wildcard_import_raises(self) -> None: + with monkeypatch_pydantic(), self.assertRaises(NonStrictTypeError): + run_test_snippet( + """ + from pydantic import * + constr() + """ + ) + def test_alternative_import_raises(self) -> None: with monkeypatch_pydantic(), self.assertRaises(NonStrictTypeError): run_test_snippet( @@ -245,7 +279,7 @@ class C: ) -class TestMetaclassPatch(unittest.TestCase): +class TestFieldTypeInspection(unittest.TestCase): @parameterized.expand( [ ("str",), @@ -253,15 +287,50 @@ class TestMetaclassPatch(unittest.TestCase): ("int",), ("float",), ("bool"), + ("Optional[str]",), + ("Union[None, str]",), + ("List[str]",), + ("List[List[str]]",), + ("Dict[StrictStr, str]",), + ("Dict[str, StrictStr]",), + ("TypedDict('D', x=int)",), ] ) - def test_field_holding_plain_value_type_raises(self, type_name: str) -> None: + def test_field_holding_unwanted_type_raises(self, annotation: str) -> None: with monkeypatch_pydantic(), self.assertRaises(NonStrictTypeError): run_test_snippet( f""" - from pydantic import BaseModel + from typing import * + from pydantic import * + class C(BaseModel): + f: {annotation} + """ + ) + + @parameterized.expand( + [ + ("StrictStr",), + ("StrictBytes"), + ("StrictInt",), + ("StrictFloat",), + ("StrictBool"), + ("constr(strict=True, min_length=10)",), + ("Optional[StrictStr]",), + ("Union[None, StrictStr]",), + ("List[StrictStr]",), + ("List[List[StrictStr]]",), + ("Dict[StrictStr, StrictStr]",), + ("TypedDict('D', x=StrictInt)",), + ] + ) + def test_field_holding_accepted_type_raises(self, annotation: str) -> None: + with monkeypatch_pydantic(): + run_test_snippet( + f""" + from typing import * + from pydantic import * class C(BaseModel): - f: {type_name} + f: {annotation} """ ) @@ -277,7 +346,7 @@ class C(BaseModel): parser = argparse.ArgumentParser() -parser.add_argument("mode", choices=["lint", "test"]) +parser.add_argument("mode", choices=["lint", "test"], default="lint") parser.add_argument("-v", "--verbose", action="store_true") From 1db56bc2f2a350e20b13dcd59970738ca9399541 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Thu, 11 Aug 2022 01:51:09 +0100 Subject: [PATCH 06/11] docstrings --- scripts-dev/check_pydantic_models.py | 28 ++++++++++++++++++++++++---- 1 file changed, 24 insertions(+), 4 deletions(-) diff --git a/scripts-dev/check_pydantic_models.py b/scripts-dev/check_pydantic_models.py index 4fa601a4c907..71a715a56312 100755 --- a/scripts-dev/check_pydantic_models.py +++ b/scripts-dev/check_pydantic_models.py @@ -66,10 +66,12 @@ class NonStrictTypeError(Exception): - ... + """Dummy exception. Allows us to detect unwanted types during a module import.""" def make_wrapper(factory: Callable[P, R]) -> Callable[P, R]: + """We patch `constr` and friends with wrappers that enforce strict=True. """ + @functools.wraps(factory) def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: # type-ignore: should be redundant once we can use https://github.com/python/mypy/pull/12668 @@ -83,6 +85,10 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: def field_type_unwanted(type_: Any) -> bool: + """Very rough attempt to detect if a type is unwanted as a Pydantic annotation. + + At present, we exclude types which will coerce, or any generic type involving types + which will coerce.""" logger.debug("Is %s unwanted?") if type_ in TYPES_THAT_PYDANTIC_WILL_COERCE_TO: logger.debug("yes") @@ -94,9 +100,11 @@ def field_type_unwanted(type_: Any) -> bool: class PatchedBaseModel(PydanticBaseModel): - """Try to detect fields whose + """A patched version of BaseModel that inspects fields after models are defined. + + We complain loudly if we see an unwanted type. - ModelField.type_ is presumably private, so this is likely to be very brittle. + Beware: ModelField.type_ is presumably private; this is likely to be very brittle. """ @classmethod @@ -110,6 +118,12 @@ def __init_subclass__(cls: Type[PydanticBaseModel], **kwargs: object): @contextmanager def monkeypatch_pydantic() -> Generator[None, None, None]: + """Patch pydantic with our snooping versions of BaseModel and the con* functions. + + Most Synapse code ought to import the patched objects directly from `pydantic`. + But we include their containing models `pydantic.main` and `pydantic.types` for + completeness. + """ with contextlib.ExitStack() as patches: patch_basemodel1 = unittest.mock.patch( "pydantic.BaseModel", new=PatchedBaseModel @@ -130,12 +144,16 @@ def monkeypatch_pydantic() -> Generator[None, None, None]: yield -def format_error(e: Exception) -> str: +def format_error(e: NonStrictTypeError) -> str: + """Work out which line of code caused e. Format the line in a human-friendly way.""" frame_summary = traceback.extract_tb(e.__traceback__)[-2] return traceback.format_list([frame_summary])[0].lstrip() def lint() -> int: + """Try to import all of Synapse and see if we spot any Pydantic type coercions. + + Print any problems, then return a status code suitable for sys.exit.""" failures = do_lint() if failures: print(f"Found {len(failures)} problem(s)") @@ -145,6 +163,7 @@ def lint() -> int: def do_lint() -> Set[str]: + """Try to import all of Synapse and see if we spot any Pydantic type coercions.""" failures = set() with monkeypatch_pydantic(): @@ -174,6 +193,7 @@ def do_lint() -> Set[str]: def run_test_snippet(source: str) -> None: + """Exec a snippet of source code in an isolated environment.""" # To emulate `source` being called at the top level of the module, # the globals and locals we provide apparently have to be the same mapping. # From a05671893331adef9a55b5af81568a4f7249b609 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Thu, 11 Aug 2022 02:34:36 +0100 Subject: [PATCH 07/11] Improve error messages --- scripts-dev/check_pydantic_models.py | 95 +++++++++++++++++++--------- 1 file changed, 65 insertions(+), 30 deletions(-) diff --git a/scripts-dev/check_pydantic_models.py b/scripts-dev/check_pydantic_models.py index 71a715a56312..4385295e8a6c 100755 --- a/scripts-dev/check_pydantic_models.py +++ b/scripts-dev/check_pydantic_models.py @@ -65,20 +65,34 @@ R = TypeVar("R") -class NonStrictTypeError(Exception): +class ModelCheckerException(Exception): """Dummy exception. Allows us to detect unwanted types during a module import.""" +class MissingStrictInConstrainedTypeException(ModelCheckerException): + factory_name: str + + def __init__(self, factory_name: str): + self.factory_name = factory_name + + +class FieldHasUnwantedTypeException(ModelCheckerException): + message: str + + def __init__(self, message: str): + self.message = message + + def make_wrapper(factory: Callable[P, R]) -> Callable[P, R]: - """We patch `constr` and friends with wrappers that enforce strict=True. """ + """We patch `constr` and friends with wrappers that enforce strict=True.""" @functools.wraps(factory) def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: # type-ignore: should be redundant once we can use https://github.com/python/mypy/pull/12668 if "strict" not in kwargs: # type: ignore[attr-defined] - raise NonStrictTypeError() + raise MissingStrictInConstrainedTypeException(factory.__name__) if not kwargs["strict"]: # type: ignore[index] - raise NonStrictTypeError() + raise MissingStrictInConstrainedTypeException(factory.__name__) return factory(*args, **kwargs) return wrapper @@ -113,18 +127,25 @@ def __init_subclass__(cls: Type[PydanticBaseModel], **kwargs: object): # Note that field.type_ and field.outer_type are computed based on the # annotation type, see pydantic.fields.ModelField._type_analysis if field_type_unwanted(field.outer_type_): - raise NonStrictTypeError() + # TODO: this only reports the first bad field. Can we find all bad ones + # and report them all? + raise FieldHasUnwantedTypeException( + f"{cls.__module__}.{cls.__qualname__} has field '{field.name}' " + f"with unwanted type `{field.outer_type_}`" + ) @contextmanager def monkeypatch_pydantic() -> Generator[None, None, None]: """Patch pydantic with our snooping versions of BaseModel and the con* functions. - Most Synapse code ought to import the patched objects directly from `pydantic`. - But we include their containing models `pydantic.main` and `pydantic.types` for - completeness. + If the snooping functions see something they don't like, they'll raise a + ModelCheckingException instance. """ with contextlib.ExitStack() as patches: + # Most Synapse code ought to import the patched objects directly from + # `pydantic`. But we also patch their containing modules `pydantic.main` and + # `pydantic.types` for completeness. patch_basemodel1 = unittest.mock.patch( "pydantic.BaseModel", new=PatchedBaseModel ) @@ -144,10 +165,20 @@ def monkeypatch_pydantic() -> Generator[None, None, None]: yield -def format_error(e: NonStrictTypeError) -> str: +def format_model_checker_exception(e: ModelCheckerException) -> str: """Work out which line of code caused e. Format the line in a human-friendly way.""" - frame_summary = traceback.extract_tb(e.__traceback__)[-2] - return traceback.format_list([frame_summary])[0].lstrip() + # TODO. FieldHasUnwantedTypeException gives better error messages. Can we ditch the + # patches of constr() etc, and instead inspect fields to look for ConstrainedStr + # with strict=False? There is some difficulty with the inheritance hierarchy + # because StrictStr < ConstrainedStr < str. + if isinstance(e, FieldHasUnwantedTypeException): + return e.message + elif isinstance(e, MissingStrictInConstrainedTypeException): + frame_summary = traceback.extract_tb(e.__traceback__)[-2] + return ( + f"Missing `strict=True` from {e.factory_name}() call \n" + + traceback.format_list([frame_summary])[0].lstrip() + ) def lint() -> int: @@ -168,26 +199,30 @@ def do_lint() -> Set[str]: with monkeypatch_pydantic(): try: - synapse = importlib.import_module("synapse") - except NonStrictTypeError as e: + # TODO: make "synapse" an argument so we can target this script at + # a subpackage + module = importlib.import_module("synapse") + except ModelCheckerException as e: logger.warning("Bad annotation found when importing synapse") - failures.add(format_error(e)) + failures.add(format_model_checker_exception(e)) return failures try: - modules = list(pkgutil.walk_packages(synapse.__path__, "synapse.")) - except NonStrictTypeError as e: + modules = list( + pkgutil.walk_packages(module.__path__, f"{module.__name__}.") + ) + except ModelCheckerException as e: logger.warning("Bad annotation found when looking for modules to import") - failures.add(format_error(e)) + failures.add(format_model_checker_exception(e)) return failures for module in modules: logger.debug("Importing %s", module.name) try: importlib.import_module(module.name) - except NonStrictTypeError as e: + except ModelCheckerException as e: logger.warning(f"Bad annotation found when importing {module.name}") - failures.add(format_error(e)) + failures.add(format_model_checker_exception(e)) return failures @@ -208,7 +243,7 @@ def run_test_snippet(source: str) -> None: class TestConstrainedTypesPatch(unittest.TestCase): def test_expression_without_strict_raises(self) -> None: - with monkeypatch_pydantic(), self.assertRaises(NonStrictTypeError): + with monkeypatch_pydantic(), self.assertRaises(ModelCheckerException): run_test_snippet( """ from pydantic import constr @@ -217,7 +252,7 @@ def test_expression_without_strict_raises(self) -> None: ) def test_called_as_module_attribute_raises(self) -> None: - with monkeypatch_pydantic(), self.assertRaises(NonStrictTypeError): + with monkeypatch_pydantic(), self.assertRaises(ModelCheckerException): run_test_snippet( """ import pydantic @@ -226,7 +261,7 @@ def test_called_as_module_attribute_raises(self) -> None: ) def test_wildcard_import_raises(self) -> None: - with monkeypatch_pydantic(), self.assertRaises(NonStrictTypeError): + with monkeypatch_pydantic(), self.assertRaises(ModelCheckerException): run_test_snippet( """ from pydantic import * @@ -235,7 +270,7 @@ def test_wildcard_import_raises(self) -> None: ) def test_alternative_import_raises(self) -> None: - with monkeypatch_pydantic(), self.assertRaises(NonStrictTypeError): + with monkeypatch_pydantic(), self.assertRaises(ModelCheckerException): run_test_snippet( """ from pydantic.types import constr @@ -244,7 +279,7 @@ def test_alternative_import_raises(self) -> None: ) def test_alternative_import_attribute_raises(self) -> None: - with monkeypatch_pydantic(), self.assertRaises(NonStrictTypeError): + with monkeypatch_pydantic(), self.assertRaises(ModelCheckerException): run_test_snippet( """ import pydantic.types @@ -253,7 +288,7 @@ def test_alternative_import_attribute_raises(self) -> None: ) def test_kwarg_but_no_strict_raises(self) -> None: - with monkeypatch_pydantic(), self.assertRaises(NonStrictTypeError): + with monkeypatch_pydantic(), self.assertRaises(ModelCheckerException): run_test_snippet( """ from pydantic import constr @@ -262,7 +297,7 @@ def test_kwarg_but_no_strict_raises(self) -> None: ) def test_kwarg_strict_False_raises(self) -> None: - with monkeypatch_pydantic(), self.assertRaises(NonStrictTypeError): + with monkeypatch_pydantic(), self.assertRaises(ModelCheckerException): run_test_snippet( """ from pydantic import constr @@ -280,7 +315,7 @@ def test_kwarg_strict_True_doesnt_raise(self) -> None: ) def test_annotation_without_strict_raises(self) -> None: - with monkeypatch_pydantic(), self.assertRaises(NonStrictTypeError): + with monkeypatch_pydantic(), self.assertRaises(ModelCheckerException): run_test_snippet( """ from pydantic import constr @@ -289,7 +324,7 @@ def test_annotation_without_strict_raises(self) -> None: ) def test_field_annotation_without_strict_raises(self) -> None: - with monkeypatch_pydantic(), self.assertRaises(NonStrictTypeError): + with monkeypatch_pydantic(), self.assertRaises(ModelCheckerException): run_test_snippet( """ from pydantic import BaseModel, conint @@ -317,7 +352,7 @@ class TestFieldTypeInspection(unittest.TestCase): ] ) def test_field_holding_unwanted_type_raises(self, annotation: str) -> None: - with monkeypatch_pydantic(), self.assertRaises(NonStrictTypeError): + with monkeypatch_pydantic(), self.assertRaises(ModelCheckerException): run_test_snippet( f""" from typing import * @@ -355,7 +390,7 @@ class C(BaseModel): ) def test_field_holding_str_raises_with_alternative_import(self) -> None: - with monkeypatch_pydantic(), self.assertRaises(NonStrictTypeError): + with monkeypatch_pydantic(), self.assertRaises(ModelCheckerException): run_test_snippet( """ from pydantic.main import BaseModel From fa62489b875fc7bdf47793f6831b4d0e3691c049 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Mon, 15 Aug 2022 21:04:45 +0100 Subject: [PATCH 08/11] lint again --- scripts-dev/check_pydantic_models.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/scripts-dev/check_pydantic_models.py b/scripts-dev/check_pydantic_models.py index 4385295e8a6c..605dda63542f 100755 --- a/scripts-dev/check_pydantic_models.py +++ b/scripts-dev/check_pydantic_models.py @@ -36,7 +36,7 @@ import traceback import unittest.mock from contextlib import contextmanager -from typing import Any, Callable, Generator, Set, Type, TypeVar, List, Dict +from typing import Any, Callable, Dict, Generator, List, Set, Type, TypeVar from parameterized import parameterized from pydantic import BaseModel as PydanticBaseModel, conbytes, confloat, conint, constr @@ -179,6 +179,8 @@ def format_model_checker_exception(e: ModelCheckerException) -> str: f"Missing `strict=True` from {e.factory_name}() call \n" + traceback.format_list([frame_summary])[0].lstrip() ) + else: + raise ValueError(f"Unknown exception {e}") from e def lint() -> int: @@ -208,7 +210,7 @@ def do_lint() -> Set[str]: return failures try: - modules = list( + module_infos = list( pkgutil.walk_packages(module.__path__, f"{module.__name__}.") ) except ModelCheckerException as e: @@ -216,12 +218,14 @@ def do_lint() -> Set[str]: failures.add(format_model_checker_exception(e)) return failures - for module in modules: - logger.debug("Importing %s", module.name) + for module_info in module_infos: + logger.debug("Importing %s", module_info.name) try: - importlib.import_module(module.name) + importlib.import_module(module_info.name) except ModelCheckerException as e: - logger.warning(f"Bad annotation found when importing {module.name}") + logger.warning( + f"Bad annotation found when importing {module_info.name}" + ) failures.add(format_model_checker_exception(e)) return failures From 3aa43857e163299be7267f889085c5fb3626fa0b Mon Sep 17 00:00:00 2001 From: David Robertson Date: Mon, 15 Aug 2022 21:23:07 +0100 Subject: [PATCH 09/11] debug --- scripts-dev/check_pydantic_models.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/scripts-dev/check_pydantic_models.py b/scripts-dev/check_pydantic_models.py index 605dda63542f..d0fb811bdb5e 100755 --- a/scripts-dev/check_pydantic_models.py +++ b/scripts-dev/check_pydantic_models.py @@ -200,6 +200,7 @@ def do_lint() -> Set[str]: failures = set() with monkeypatch_pydantic(): + logger.debug("Importing synapse") try: # TODO: make "synapse" an argument so we can target this script at # a subpackage @@ -210,6 +211,7 @@ def do_lint() -> Set[str]: return failures try: + logger.debug("Fetching subpackages") module_infos = list( pkgutil.walk_packages(module.__path__, f"{module.__name__}.") ) @@ -382,7 +384,7 @@ class C(BaseModel): ("TypedDict('D', x=StrictInt)",), ] ) - def test_field_holding_accepted_type_raises(self, annotation: str) -> None: + def test_field_holding_accepted_type_doesnt_raise(self, annotation: str) -> None: with monkeypatch_pydantic(): run_test_snippet( f""" @@ -405,7 +407,7 @@ class C(BaseModel): parser = argparse.ArgumentParser() -parser.add_argument("mode", choices=["lint", "test"], default="lint") +parser.add_argument("mode", choices=["lint", "test"], default="lint", nargs="?") parser.add_argument("-v", "--verbose", action="store_true") From d29983ac60e9b27f47b2216512667cdb6cde5dcd Mon Sep 17 00:00:00 2001 From: David Robertson Date: Mon, 15 Aug 2022 21:26:33 +0100 Subject: [PATCH 10/11] Changelog --- changelog.d/13502.misc | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog.d/13502.misc diff --git a/changelog.d/13502.misc b/changelog.d/13502.misc new file mode 100644 index 000000000000..ed6832996e06 --- /dev/null +++ b/changelog.d/13502.misc @@ -0,0 +1 @@ +Add a linter script which will reject non-strict types in Pydantic models. From efea8cf2e33ac234ce9a966d26cdcf3e175d743a Mon Sep 17 00:00:00 2001 From: David Robertson Date: Mon, 15 Aug 2022 21:29:32 +0100 Subject: [PATCH 11/11] Fix CI job --- .github/workflows/tests.yml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 92dc58564eef..144cb9ffaac2 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -62,8 +62,7 @@ jobs: fetch-depth: 0 - uses: matrix-org/setup-python-poetry@v1 with: - python-version: ${{ matrix.python-version }} - extras: ${{ matrix.extras }} + extras: "all" - run: poetry run scripts-dev/check_pydantic_models.py # Dummy step to gate other tests on without repeating the whole list