From c5e4e0f5e818639c573fb7ab4f33eb237754c15c Mon Sep 17 00:00:00 2001 From: Ashley Whetter Date: Sun, 11 Apr 2021 16:00:37 -0700 Subject: [PATCH] Add plugin for functools.total_ordering (#7831) --- mypy/plugins/default.py | 4 + mypy/plugins/functools.py | 105 +++++++++++++++++++++++++++ mypy/test/testcheck.py | 1 + test-data/unit/check-functools.test | 109 ++++++++++++++++++++++++++++ 4 files changed, 219 insertions(+) create mode 100644 mypy/plugins/functools.py create mode 100644 test-data/unit/check-functools.test diff --git a/mypy/plugins/default.py b/mypy/plugins/default.py index 3621e4e4de7a..552a52c5c860 100644 --- a/mypy/plugins/default.py +++ b/mypy/plugins/default.py @@ -93,6 +93,7 @@ def get_class_decorator_hook(self, fullname: str ) -> Optional[Callable[[ClassDefContext], None]]: from mypy.plugins import attrs from mypy.plugins import dataclasses + from mypy.plugins import functools if fullname in attrs.attr_class_makers: return attrs.attr_class_maker_callback @@ -114,6 +115,9 @@ def get_class_decorator_hook(self, fullname: str ) elif fullname in dataclasses.dataclass_makers: return dataclasses.dataclass_class_maker_callback + elif fullname in functools.functools_total_ordering_makers: + return functools.functools_total_ordering_maker_callback + return None diff --git a/mypy/plugins/functools.py b/mypy/plugins/functools.py new file mode 100644 index 000000000000..d2905c06c2e8 --- /dev/null +++ b/mypy/plugins/functools.py @@ -0,0 +1,105 @@ +"""Plugin for supporting the functools standard library module.""" +from typing import Dict, NamedTuple, Optional + +import mypy.plugin +from mypy.nodes import ARG_OPT, ARG_POS, ARG_STAR2, Argument, FuncItem, Var +from mypy.plugins.common import add_method_to_class +from mypy.types import AnyType, CallableType, get_proper_type, Type, TypeOfAny, UnboundType + + +functools_total_ordering_makers = { + 'functools.total_ordering', +} + +_ORDERING_METHODS = { + '__lt__', + '__le__', + '__gt__', + '__ge__', +} + + +_MethodInfo = NamedTuple('_MethodInfo', [('is_static', bool), ('type', CallableType)]) + + +def functools_total_ordering_maker_callback(ctx: mypy.plugin.ClassDefContext, + auto_attribs_default: bool = False) -> None: + """Add dunder methods to classes decorated with functools.total_ordering.""" + if ctx.api.options.python_version < (3,): + ctx.api.fail('"functools.total_ordering" is not supported in Python 2', ctx.reason) + return + + comparison_methods = _analyze_class(ctx) + if not comparison_methods: + ctx.api.fail( + 'No ordering operation defined when using "functools.total_ordering": < > <= >=', + ctx.reason) + return + + # prefer __lt__ to __le__ to __gt__ to __ge__ + root = max(comparison_methods, key=lambda k: (comparison_methods[k] is None, k)) + root_method = comparison_methods[root] + if not root_method: + # None of the defined comparison methods can be analysed + return + + other_type = _find_other_type(root_method) + bool_type = ctx.api.named_type('__builtins__.bool') + ret_type = bool_type # type: Type + if root_method.type.ret_type != ctx.api.named_type('__builtins__.bool'): + proper_ret_type = get_proper_type(root_method.type.ret_type) + if not (isinstance(proper_ret_type, UnboundType) + and proper_ret_type.name.split('.')[-1] == 'bool'): + ret_type = AnyType(TypeOfAny.implementation_artifact) + for additional_op in _ORDERING_METHODS: + # Either the method is not implemented + # or has an unknown signature that we can now extrapolate. + if not comparison_methods.get(additional_op): + args = [Argument(Var('other', other_type), other_type, None, ARG_POS)] + add_method_to_class(ctx.api, ctx.cls, additional_op, args, ret_type) + + +def _find_other_type(method: _MethodInfo) -> Type: + """Find the type of the ``other`` argument in a comparison method.""" + first_arg_pos = 0 if method.is_static else 1 + cur_pos_arg = 0 + other_arg = None + for arg_kind, arg_type in zip(method.type.arg_kinds, method.type.arg_types): + if arg_kind in (ARG_POS, ARG_OPT): + if cur_pos_arg == first_arg_pos: + other_arg = arg_type + break + + cur_pos_arg += 1 + elif arg_kind != ARG_STAR2: + other_arg = arg_type + break + + if other_arg is None: + return AnyType(TypeOfAny.implementation_artifact) + + return other_arg + + +def _analyze_class(ctx: mypy.plugin.ClassDefContext) -> Dict[str, Optional[_MethodInfo]]: + """Analyze the class body, its parents, and return the comparison methods found.""" + # Traverse the MRO and collect ordering methods. + comparison_methods = {} # type: Dict[str, Optional[_MethodInfo]] + # Skip object because total_ordering does not use methods from object + for cls in ctx.cls.info.mro[:-1]: + for name in _ORDERING_METHODS: + if name in cls.names and name not in comparison_methods: + node = cls.names[name].node + if isinstance(node, FuncItem) and isinstance(node.type, CallableType): + comparison_methods[name] = _MethodInfo(node.is_static, node.type) + continue + + if isinstance(node, Var): + proper_type = get_proper_type(node.type) + if isinstance(proper_type, CallableType): + comparison_methods[name] = _MethodInfo(node.is_staticmethod, proper_type) + continue + + comparison_methods[name] = None + + return comparison_methods diff --git a/mypy/test/testcheck.py b/mypy/test/testcheck.py index 51f5d71c12ad..f6b36c376180 100644 --- a/mypy/test/testcheck.py +++ b/mypy/test/testcheck.py @@ -94,6 +94,7 @@ 'check-parameter-specification.test', 'check-generic-alias.test', 'check-typeguard.test', + 'check-functools.test', ] # Tests that use Python 3.8-only AST features (like expression-scoped ignores): diff --git a/test-data/unit/check-functools.test b/test-data/unit/check-functools.test new file mode 100644 index 000000000000..416006591425 --- /dev/null +++ b/test-data/unit/check-functools.test @@ -0,0 +1,109 @@ +[case testTotalOrderingEqLt] +from functools import total_ordering + +@total_ordering +class Ord: + def __eq__(self, other: object) -> bool: + return False + + def __lt__(self, other: "Ord") -> bool: + return False + +reveal_type(Ord() < Ord()) # N: Revealed type is "builtins.bool" +reveal_type(Ord() <= Ord()) # N: Revealed type is "builtins.bool" +reveal_type(Ord() == Ord()) # N: Revealed type is "builtins.bool" +reveal_type(Ord() > Ord()) # N: Revealed type is "builtins.bool" +reveal_type(Ord() >= Ord()) # N: Revealed type is "builtins.bool" + +Ord() < 1 # E: Unsupported operand types for < ("Ord" and "int") +Ord() <= 1 # E: Unsupported operand types for <= ("Ord" and "int") +Ord() == 1 +Ord() > 1 # E: Unsupported operand types for > ("Ord" and "int") +Ord() >= 1 # E: Unsupported operand types for >= ("Ord" and "int") +[builtins fixtures/ops.pyi] +[builtins fixtures/dict.pyi] + +[case testTotalOrderingLambda] +from functools import total_ordering +from typing import Any, Callable + +@total_ordering +class Ord: + __eq__: Callable[[Any, object], bool] = lambda self, other: False + __lt__: Callable[[Any, "Ord"], bool] = lambda self, other: False + +reveal_type(Ord() < Ord()) # N: Revealed type is "builtins.bool" +reveal_type(Ord() <= Ord()) # N: Revealed type is "builtins.bool" +reveal_type(Ord() == Ord()) # N: Revealed type is "builtins.bool" +reveal_type(Ord() > Ord()) # N: Revealed type is "builtins.bool" +reveal_type(Ord() >= Ord()) # N: Revealed type is "builtins.bool" + +Ord() < 1 # E: Argument 1 has incompatible type "int"; expected "Ord" +Ord() <= 1 # E: Unsupported operand types for <= ("Ord" and "int") +Ord() == 1 +Ord() > 1 # E: Unsupported operand types for > ("Ord" and "int") +Ord() >= 1 # E: Unsupported operand types for >= ("Ord" and "int") +[builtins fixtures/ops.pyi] +[builtins fixtures/dict.pyi] + +[case testTotalOrderingNonCallable] +from functools import total_ordering + +@total_ordering +class Ord(object): + def __eq__(self, other: object) -> bool: + return False + + __lt__ = 5 + +Ord() <= Ord() # E: Unsupported left operand type for <= ("Ord") +Ord() > Ord() # E: "int" not callable +Ord() >= Ord() # E: Unsupported left operand type for >= ("Ord") + +[builtins fixtures/ops.pyi] +[builtins fixtures/dict.pyi] + +[case testTotalOrderingReturnNotBool] +from functools import total_ordering + +@total_ordering +class Ord: + def __eq__(self, other: object) -> bool: + return False + + def __lt__(self, other: "Ord") -> str: + return "blah" + +reveal_type(Ord() < Ord()) # N: Revealed type is "builtins.str" +reveal_type(Ord() <= Ord()) # N: Revealed type is "Any" +reveal_type(Ord() == Ord()) # N: Revealed type is "builtins.bool" +reveal_type(Ord() > Ord()) # N: Revealed type is "Any" +reveal_type(Ord() >= Ord()) # N: Revealed type is "Any" + +[builtins fixtures/ops.pyi] +[builtins fixtures/dict.pyi] + +[case testTotalOrderingAllowsAny] +from functools import total_ordering + +@total_ordering +class Ord: + def __eq__(self, other): + return False + + def __gt__(self, other): + return False + +reveal_type(Ord() < Ord()) # N: Revealed type is "Any" +Ord() <= Ord() # E: Unsupported left operand type for <= ("Ord") +reveal_type(Ord() == Ord()) # N: Revealed type is "Any" +reveal_type(Ord() > Ord()) # N: Revealed type is "Any" +Ord() >= Ord() # E: Unsupported left operand type for >= ("Ord") + +Ord() < 1 # E: Unsupported left operand type for < ("Ord") +Ord() <= 1 # E: Unsupported left operand type for <= ("Ord") +Ord() == 1 +Ord() > 1 +Ord() >= 1 # E: Unsupported left operand type for >= ("Ord") +[builtins fixtures/ops.pyi] +[builtins fixtures/dict.pyi]