diff --git a/pytype/analyze.py b/pytype/analyze.py index 72791fa62..43e2adb3a 100644 --- a/pytype/analyze.py +++ b/pytype/analyze.py @@ -12,6 +12,7 @@ from pytype import function from pytype import metrics from pytype import output +from pytype import special_builtins from pytype import state as frame_state from pytype import vm from pytype.overlays import typing_overlay @@ -175,6 +176,8 @@ def maybe_analyze_method(self, node, val, cls=None): else: for f in method.iter_signature_functions(): node, args = self.create_method_arguments(node, f) + if f.is_classmethod and cls: + args = self._maybe_fix_classmethod_cls_arg(node, cls, f, args) node, _ = self.call_function_with_args(node, val, args) return node @@ -227,7 +230,11 @@ def analyze_method_var(self, node0, name, var, cls=None): def bind_method(self, node, name, methodvar, instance_var): bound = self.program.NewVariable() for m in methodvar.Data(node): - is_cls = False + if isinstance(m, special_builtins.ClassMethodInstance): + m = m.func.data[0] + is_cls = True + else: + is_cls = (m.isinstance_InterpreterFunction() and m.is_classmethod) bound.AddBinding(m.property_get(instance_var, is_cls), [], node) return bound diff --git a/pytype/tests/test_classes.py b/pytype/tests/test_classes.py index d29036090..963eda595 100644 --- a/pytype/tests/test_classes.py +++ b/pytype/tests/test_classes.py @@ -165,7 +165,6 @@ class Foo(object): def bar(cls) -> None: ... """) - @test_base.skip("Temporary rollback") def test_factory_classmethod(self): ty = self.Infer(""" class Foo(object): @@ -181,7 +180,6 @@ class Foo: def factory(cls: Type[_TFoo], *args, **kwargs) -> _TFoo: ... """) - @test_base.skip("Temporary rollback") def test_classmethod_return_inference(self): ty = self.Infer(""" class Foo(object): diff --git a/pytype/tests/test_cmp.py b/pytype/tests/test_cmp.py index 41124ffa0..229bc45e3 100644 --- a/pytype/tests/test_cmp.py +++ b/pytype/tests/test_cmp.py @@ -155,7 +155,6 @@ class Foo: def __new__(cls: Type[_TFoo], *args, **kwargs) -> _TFoo: ... """) - @test_base.skip("Temporary rollback") def test_class_factory(self): # The assert should not block inference of the return type, since cls could # be a subclass of Foo diff --git a/pytype/vm.py b/pytype/vm.py index ee2166bb0..731ca96dd 100644 --- a/pytype/vm.py +++ b/pytype/vm.py @@ -1604,7 +1604,7 @@ def _is_classmethod_cls_arg(self, var): return False func = self.frame.func.data - if func.name.rsplit(".")[-1] == "__new__": + if func.is_classmethod or func.name.rsplit(".")[-1] == "__new__": is_cls = not set(var.data) - set(self.frame.first_posarg.data) return is_cls return False